sql_splitter/sample/
mod.rs

1//! Sample command for creating reduced datasets from SQL dumps.
2//!
3//! The sample command creates reduced datasets while optionally preserving
4//! foreign key integrity through dependency-aware FK chain resolution.
5//!
6//! Supports MySQL, PostgreSQL, and SQLite dialects.
7
8mod config;
9mod reservoir;
10
11pub use config::{DefaultClassifier, GlobalTableMode, SampleYamlConfig, TableClassification};
12pub use reservoir::Reservoir;
13
14use crate::parser::mysql_insert::{parse_mysql_insert_rows, ParsedRow, PkSet};
15use crate::parser::postgres_copy::{parse_copy_columns, parse_postgres_copy_rows, ParsedCopyRow};
16use crate::parser::{ContentFilter, Parser, SqlDialect, StatementType};
17use crate::schema::{SchemaBuilder, SchemaGraph, TableId};
18use crate::splitter::Splitter;
19use ahash::AHashMap;
20use indicatif::{ProgressBar, ProgressStyle};
21use rand::rngs::StdRng;
22use rand::{Rng, SeedableRng};
23use std::fs::{self, File};
24use std::io::{BufWriter, Write};
25use std::path::{Path, PathBuf};
26use tempfile::TempDir;
27
/// How many rows to keep from each sampled table.
#[derive(Clone, Copy, Debug)]
pub enum SampleMode {
    /// Keep each row independently with probability `N / 100`
    /// (values above 100 are clamped, i.e. every row is kept).
    Percent(u32),
    /// Keep at most `N` rows, chosen by reservoir sampling.
    Rows(usize),
}
36
37/// Configuration for the sample command
38#[derive(Debug)]
39pub struct SampleConfig {
40    /// Input SQL file
41    pub input: PathBuf,
42    /// Output SQL file (None for stdout)
43    pub output: Option<PathBuf>,
44    /// SQL dialect
45    pub dialect: SqlDialect,
46    /// Sampling mode
47    pub mode: SampleMode,
48    /// Preserve foreign key relationships
49    pub preserve_relations: bool,
50    /// Only sample these tables (None = all)
51    pub tables_filter: Option<Vec<String>>,
52    /// Exclude these tables
53    pub exclude: Vec<String>,
54    /// Root tables for sampling (start from these)
55    pub root_tables: Vec<String>,
56    /// How to handle global/lookup tables
57    pub include_global: GlobalTableMode,
58    /// Random seed for reproducibility
59    pub seed: u64,
60    /// Dry run mode (show stats only)
61    pub dry_run: bool,
62    /// Show progress
63    pub progress: bool,
64    /// YAML config file path
65    pub config_file: Option<PathBuf>,
66    /// Maximum total rows to sample (explosion guard)
67    pub max_total_rows: Option<usize>,
68    /// Fail if any FK integrity issues detected
69    pub strict_fk: bool,
70    /// Include schema statements in output
71    pub include_schema: bool,
72}
73
74impl Default for SampleConfig {
75    fn default() -> Self {
76        Self {
77            input: PathBuf::new(),
78            output: None,
79            dialect: SqlDialect::MySql,
80            mode: SampleMode::Percent(10),
81            preserve_relations: false,
82            tables_filter: None,
83            exclude: Vec::new(),
84            root_tables: Vec::new(),
85            include_global: GlobalTableMode::Lookups,
86            seed: rand::random(),
87            dry_run: false,
88            progress: false,
89            config_file: None,
90            max_total_rows: None,
91            strict_fk: false,
92            include_schema: true,
93        }
94    }
95}
96
97/// Statistics from sample operation
98#[derive(Debug, Default)]
99pub struct SampleStats {
100    /// Number of tables sampled
101    pub tables_sampled: usize,
102    /// Number of tables skipped
103    pub tables_skipped: usize,
104    /// Total rows selected
105    pub total_rows_selected: u64,
106    /// Total rows seen
107    pub total_rows_seen: u64,
108    /// Per-table statistics
109    pub table_stats: Vec<TableSampleStats>,
110    /// Warning messages
111    pub warnings: Vec<String>,
112    /// FK orphan count (rows rejected due to missing parents)
113    pub fk_orphans_rejected: u64,
114}
115
116/// Per-table sampling statistics
117#[derive(Debug, Clone)]
118pub struct TableSampleStats {
119    pub name: String,
120    pub rows_seen: u64,
121    pub rows_selected: u64,
122    pub classification: TableClassification,
123}
124
125/// Runtime state for a table during sampling
126struct TableRuntime {
127    /// Table name
128    name: String,
129    /// Selected rows with format metadata
130    selected_rows: Vec<SelectedRow>,
131    /// Primary key set for FK membership checks
132    pk_set: PkSet,
133    /// Rows seen count
134    rows_seen: u64,
135    /// Whether to skip this table
136    skip: bool,
137    /// Table classification
138    classification: TableClassification,
139    /// FK orphans rejected for this table
140    fk_orphans: u64,
141}
142
143/// Combined row representation for both MySQL INSERT and PostgreSQL COPY
144enum UnifiedRow {
145    Insert(ParsedRow),
146    Copy(ParsedCopyRow),
147}
148
/// Source format of a selected row; decides how it is rendered in the output.
#[derive(Clone, Copy, Debug, PartialEq)]
enum RowFormat {
    /// Raw bytes are an INSERT values tuple.
    Insert,
    /// Raw bytes are a tab-separated COPY text line.
    Copy,
}
155
156/// Selected row with format metadata
157struct SelectedRow {
158    raw: Vec<u8>,
159    format: RowFormat,
160}
161
162impl UnifiedRow {
163    fn pk(&self) -> Option<&smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>> {
164        match self {
165            UnifiedRow::Insert(r) => r.pk.as_ref(),
166            UnifiedRow::Copy(r) => r.pk.as_ref(),
167        }
168    }
169
170    fn fk_values(
171        &self,
172    ) -> &[(
173        crate::parser::mysql_insert::FkRef,
174        smallvec::SmallVec<[crate::parser::mysql_insert::PkValue; 2]>,
175    )] {
176        match self {
177            UnifiedRow::Insert(r) => &r.fk_values,
178            UnifiedRow::Copy(r) => &r.fk_values,
179        }
180    }
181
182    fn into_selected(self) -> SelectedRow {
183        match self {
184            UnifiedRow::Insert(r) => SelectedRow {
185                raw: r.raw,
186                format: RowFormat::Insert,
187            },
188            UnifiedRow::Copy(r) => SelectedRow {
189                raw: r.raw,
190                format: RowFormat::Copy,
191            },
192        }
193    }
194}
195
196/// Run the sample command
197pub fn run(config: SampleConfig) -> anyhow::Result<SampleStats> {
198    // Load YAML config if provided
199    let yaml_config = if let Some(ref path) = config.config_file {
200        Some(SampleYamlConfig::load(path)?)
201    } else {
202        None
203    };
204
205    let mut rng = StdRng::seed_from_u64(config.seed);
206    let mut stats = SampleStats::default();
207
208    // Get file size for progress tracking
209    let file_size = std::fs::metadata(&config.input)?.len();
210
211    // Progress bar setup - byte-based for the split phase
212    let progress_bar = if config.progress {
213        let pb = ProgressBar::new(file_size);
214        pb.set_style(
215            ProgressStyle::with_template(
216                "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%) {msg}",
217            )
218            .unwrap()
219            .progress_chars("█▓▒░  ")
220            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
221        );
222        pb.enable_steady_tick(std::time::Duration::from_millis(100));
223        pb.set_message("Splitting dump...");
224        Some(pb)
225    } else {
226        None
227    };
228
229    // Phase 0: Split into temp per-table files
230    let temp_dir = TempDir::new()?;
231    let tables_dir = temp_dir.path().join("tables");
232
233    let mut splitter = Splitter::new(config.input.clone(), tables_dir.clone())
234        .with_dialect(config.dialect)
235        .with_content_filter(ContentFilter::All);
236
237    if let Some(ref pb) = progress_bar {
238        let pb_clone = pb.clone();
239        splitter = splitter.with_progress(move |bytes| {
240            pb_clone.set_position(bytes);
241        });
242    }
243
244    let split_stats = splitter.split()?;
245
246    // Finish byte-based progress, switch to milestone messages
247    if let Some(ref pb) = progress_bar {
248        pb.finish_and_clear();
249    }
250
251    if config.progress {
252        eprintln!(
253            "Split complete: {} tables, {} statements",
254            split_stats.tables_found, split_stats.statements_processed
255        );
256    }
257
258    // Phase 1: Build schema graph
259    if config.progress {
260        eprintln!("Building schema graph...");
261    }
262
263    let graph = build_schema_graph(&tables_dir, &config)?;
264
265    let (topo_order, cyclic_tables) = graph.processing_order();
266
267    if !cyclic_tables.is_empty() {
268        let names: Vec<_> = cyclic_tables
269            .iter()
270            .filter_map(|&id| graph.table_name(id))
271            .collect();
272        let msg = format!(
273            "Warning: {} tables have FK cycles (intra-cycle FK enforcement disabled): {:?}",
274            cyclic_tables.len(),
275            names
276        );
277        if config.progress {
278            eprintln!("{}", msg);
279        }
280        stats.warnings.push(msg);
281    }
282
283    // Build set of cyclic table IDs for quick lookup
284    let cyclic_set: ahash::AHashSet<TableId> = cyclic_tables.iter().copied().collect();
285
286    // Determine root tables
287    let explicit_roots: ahash::AHashSet<String> = config
288        .root_tables
289        .iter()
290        .map(|s| s.to_lowercase())
291        .collect();
292
293    // Initialize table runtimes with classification
294    let mut runtimes: AHashMap<TableId, TableRuntime> = AHashMap::new();
295    for table in graph.schema.iter() {
296        let classification =
297            determine_classification(&table.name, &graph, table.id, &yaml_config, &explicit_roots);
298        let skip = should_skip_table(&table.name, &config, &yaml_config, classification);
299
300        runtimes.insert(
301            table.id,
302            TableRuntime {
303                name: table.name.clone(),
304                selected_rows: Vec::new(),
305                pk_set: PkSet::default(),
306                rows_seen: 0,
307                skip,
308                classification,
309                fk_orphans: 0,
310            },
311        );
312    }
313
314    // Phase 2: Process tables in dependency order
315    if config.progress {
316        eprintln!("Sampling {} tables in dependency order...", topo_order.len());
317    }
318
319    // Process acyclic tables first, then cyclic tables
320    let all_tables: Vec<TableId> = topo_order.into_iter().chain(cyclic_tables).collect();
321
322    let mut total_selected: u64 = 0;
323
324    for table_id in &all_tables {
325        let table_schema = match graph.schema.table(*table_id) {
326            Some(s) => s,
327            None => continue,
328        };
329
330        // Check if we should skip this table
331        let (should_skip, table_name, classification) = {
332            let runtime = match runtimes.get(table_id) {
333                Some(r) => r,
334                None => continue,
335            };
336            (runtime.skip, runtime.name.clone(), runtime.classification)
337        };
338
339        if should_skip {
340            stats.tables_skipped += 1;
341            continue;
342        }
343
344        // Handle lookup/global tables specially
345        let sample_mode = match classification {
346            TableClassification::Lookup => {
347                match config.include_global {
348                    GlobalTableMode::None => {
349                        stats.tables_skipped += 1;
350                        continue;
351                    }
352                    GlobalTableMode::Lookups | GlobalTableMode::All => {
353                        // Include all rows
354                        SampleMode::Percent(100)
355                    }
356                }
357            }
358            TableClassification::System => {
359                stats.tables_skipped += 1;
360                continue;
361            }
362            _ => get_table_sample_mode(&table_name, &config, &yaml_config),
363        };
364
365        let table_file = tables_dir.join(format!("{}.sql", table_name));
366        if !table_file.exists() {
367            continue;
368        }
369
370        // Parse statements from this table file
371        let file = File::open(&table_file)?;
372        let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
373
374        // Collect rows based on sampling mode
375        let mut candidates: Vec<UnifiedRow> = Vec::new();
376        let mut rows_seen = 0u64;
377        let mut fk_orphans = 0u64;
378
379        // For PostgreSQL COPY, track the current column order
380        let mut copy_columns: Vec<String> = Vec::new();
381
382        while let Some(stmt) = parser.read_statement()? {
383            let (stmt_type, _) =
384                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
385
386            match stmt_type {
387                StatementType::Insert => {
388                    let rows = parse_mysql_insert_rows(&stmt, table_schema)?;
389
390                    for row in rows {
391                        rows_seen += 1;
392                        let unified = UnifiedRow::Insert(row);
393
394                        if config.preserve_relations {
395                            let (passes, orphan) = check_unified_fk_membership(
396                                &unified,
397                                table_schema,
398                                &runtimes,
399                                &cyclic_set,
400                                table_id,
401                            );
402                            if !passes {
403                                fk_orphans += 1;
404                                if orphan && config.strict_fk {
405                                    anyhow::bail!(
406                                        "FK integrity violation in table '{}': row references missing parent",
407                                        table_name
408                                    );
409                                }
410                                continue;
411                            }
412                        }
413
414                        candidates.push(unified);
415                    }
416                }
417                StatementType::Copy => {
418                    // Extract column order from COPY header
419                    let header = String::from_utf8_lossy(&stmt);
420                    copy_columns = parse_copy_columns(&header);
421                }
422                StatementType::Unknown if config.dialect == SqlDialect::Postgres => {
423                    // This might be COPY data
424                    if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
425                        let rows =
426                            parse_postgres_copy_rows(&stmt, table_schema, copy_columns.clone())?;
427
428                        for row in rows {
429                            rows_seen += 1;
430                            let unified = UnifiedRow::Copy(row);
431
432                            if config.preserve_relations {
433                                let (passes, orphan) = check_unified_fk_membership(
434                                    &unified,
435                                    table_schema,
436                                    &runtimes,
437                                    &cyclic_set,
438                                    table_id,
439                                );
440                                if !passes {
441                                    fk_orphans += 1;
442                                    if orphan && config.strict_fk {
443                                        anyhow::bail!(
444                                            "FK integrity violation in table '{}': row references missing parent",
445                                            table_name
446                                        );
447                                    }
448                                    continue;
449                                }
450                            }
451
452                            candidates.push(unified);
453                        }
454                    }
455                }
456                _ => {}
457            }
458        }
459
460        // Check max_total_rows guard
461        if let Some(max) = config.max_total_rows {
462            if total_selected + candidates.len() as u64 > max as u64 {
463                let msg = format!(
464                    "Warning: Reached max_total_rows limit ({}) at table '{}'",
465                    max, table_name
466                );
467                stats.warnings.push(msg);
468                break;
469            }
470        }
471
472        // Apply sampling to candidates
473        let selected = sample_rows(&candidates, sample_mode, &mut rng);
474
475        // Update total count
476        total_selected += selected.len() as u64;
477
478        // Store selected rows and update PK set
479        let runtime = runtimes.get_mut(table_id).unwrap();
480        runtime.rows_seen = rows_seen;
481        runtime.fk_orphans = fk_orphans;
482
483        for row in selected {
484            if let Some(pk) = row.pk() {
485                runtime.pk_set.insert(pk.clone());
486            }
487            runtime.selected_rows.push(row.into_selected());
488        }
489
490        stats.fk_orphans_rejected += fk_orphans;
491
492        stats.table_stats.push(TableSampleStats {
493            name: runtime.name.clone(),
494            rows_seen: runtime.rows_seen,
495            rows_selected: runtime.selected_rows.len() as u64,
496            classification: runtime.classification,
497        });
498    }
499
500    // Calculate totals
501    for table_stats in &stats.table_stats {
502        stats.total_rows_seen += table_stats.rows_seen;
503        stats.total_rows_selected += table_stats.rows_selected;
504    }
505    stats.tables_sampled = stats.table_stats.len();
506
507    if config.progress {
508        eprintln!("Sampling complete");
509    }
510
511    // Phase 3: Output synthesis
512    if config.dry_run {
513        return Ok(stats);
514    }
515
516    if config.progress {
517        eprintln!("Writing output...");
518    }
519
520    write_output(&config, &graph, &all_tables, &runtimes, &tables_dir, &stats)?;
521
522    Ok(stats)
523}
524
525/// Build schema graph from split table files
526fn build_schema_graph(tables_dir: &Path, config: &SampleConfig) -> anyhow::Result<SchemaGraph> {
527    let mut builder = SchemaBuilder::new();
528
529    for entry in fs::read_dir(tables_dir)? {
530        let entry = entry?;
531        let path = entry.path();
532
533        if path.extension().map(|e| e == "sql").unwrap_or(false) {
534            let file = File::open(&path)?;
535            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
536
537            while let Some(stmt) = parser.read_statement()? {
538                let stmt_str = String::from_utf8_lossy(&stmt);
539                let (stmt_type, _) =
540                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
541
542                match stmt_type {
543                    StatementType::CreateTable => {
544                        builder.parse_create_table(&stmt_str);
545                    }
546                    StatementType::AlterTable => {
547                        builder.parse_alter_table(&stmt_str);
548                    }
549                    _ => {}
550                }
551            }
552        }
553    }
554
555    Ok(SchemaGraph::from_schema(builder.build()))
556}
557
558/// Determine table classification
559fn determine_classification(
560    name: &str,
561    graph: &SchemaGraph,
562    table_id: TableId,
563    yaml_config: &Option<SampleYamlConfig>,
564    explicit_roots: &ahash::AHashSet<String>,
565) -> TableClassification {
566    // Check explicit roots first
567    if explicit_roots.contains(&name.to_lowercase()) {
568        return TableClassification::Root;
569    }
570
571    // Check YAML config
572    if let Some(ref config) = yaml_config {
573        let class = config.get_classification(name);
574        if class != TableClassification::Normal {
575            return class;
576        }
577    }
578
579    // Check if it's a graph root (no parents)
580    if graph.parents[table_id.0 as usize].is_empty() {
581        return TableClassification::Root;
582    }
583
584    // Use default classifier
585    DefaultClassifier::classify(name)
586}
587
588/// Check if a table should be skipped
589fn should_skip_table(
590    name: &str,
591    config: &SampleConfig,
592    yaml_config: &Option<SampleYamlConfig>,
593    classification: TableClassification,
594) -> bool {
595    let name_lower = name.to_lowercase();
596
597    // Check exclude list
598    if config
599        .exclude
600        .iter()
601        .any(|e| e.to_lowercase() == name_lower)
602    {
603        return true;
604    }
605
606    // Check YAML skip
607    if let Some(ref yc) = yaml_config {
608        if yc.should_skip(name) {
609            return true;
610        }
611    }
612
613    // Check include filter
614    if let Some(ref filter) = config.tables_filter {
615        if !filter.iter().any(|f| f.to_lowercase() == name_lower) {
616            return true;
617        }
618    }
619
620    // Skip system tables by default
621    if classification == TableClassification::System {
622        return true;
623    }
624
625    false
626}
627
628/// Get sample mode for a specific table
629fn get_table_sample_mode(
630    name: &str,
631    config: &SampleConfig,
632    yaml_config: &Option<SampleYamlConfig>,
633) -> SampleMode {
634    // Check YAML config first
635    if let Some(ref yc) = yaml_config {
636        if let Some(rows) = yc.get_rows(name) {
637            return SampleMode::Rows(rows);
638        }
639        if let Some(percent) = yc.get_percent(name) {
640            return SampleMode::Percent(percent);
641        }
642    }
643
644    // Fall back to global config
645    config.mode
646}
647
648/// Check FK membership for a unified row (works with both INSERT and COPY rows)
649fn check_unified_fk_membership(
650    row: &UnifiedRow,
651    table_schema: &crate::schema::TableSchema,
652    runtimes: &AHashMap<TableId, TableRuntime>,
653    cyclic_set: &ahash::AHashSet<TableId>,
654    current_table_id: &TableId,
655) -> (bool, bool) {
656    let mut passes = true;
657    let mut is_orphan = false;
658
659    for (fk_ref, fk_tuple) in row.fk_values() {
660        if let Some(fk) = table_schema.foreign_keys.get(fk_ref.fk_index as usize) {
661            if let Some(parent_id) = fk.referenced_table_id {
662                // Skip FK check for cyclic tables
663                if cyclic_set.contains(&parent_id) && cyclic_set.contains(current_table_id) {
664                    continue;
665                }
666
667                // Check if parent row exists in parent's pk_set
668                if let Some(parent_runtime) = runtimes.get(&parent_id) {
669                    if !parent_runtime.pk_set.contains(fk_tuple) {
670                        passes = false;
671                        is_orphan = true;
672                        break;
673                    }
674                }
675            }
676        }
677    }
678
679    (passes, is_orphan)
680}
681
682/// Sample rows according to sampling mode
683fn sample_rows(candidates: &[UnifiedRow], mode: SampleMode, rng: &mut StdRng) -> Vec<UnifiedRow> {
684    match mode {
685        SampleMode::Percent(p) => {
686            // Bernoulli sampling
687            let prob = p as f64 / 100.0;
688            candidates
689                .iter()
690                .filter(|_| rng.gen_bool(prob.min(1.0)))
691                .map(|r| match r {
692                    UnifiedRow::Insert(row) => UnifiedRow::Insert(row.clone()),
693                    UnifiedRow::Copy(row) => UnifiedRow::Copy(row.clone()),
694                })
695                .collect()
696        }
697        SampleMode::Rows(n) => {
698            // Reservoir sampling
699            let mut reservoir = Reservoir::new(n, StdRng::from_rng(rng).unwrap());
700            for row in candidates {
701                let cloned = match row {
702                    UnifiedRow::Insert(r) => UnifiedRow::Insert(r.clone()),
703                    UnifiedRow::Copy(r) => UnifiedRow::Copy(r.clone()),
704                };
705                reservoir.consider(cloned);
706            }
707            reservoir.into_items()
708        }
709    }
710}
711
712/// Write sampled output
713fn write_output(
714    config: &SampleConfig,
715    _graph: &SchemaGraph,
716    table_order: &[TableId],
717    runtimes: &AHashMap<TableId, TableRuntime>,
718    tables_dir: &Path,
719    stats: &SampleStats,
720) -> anyhow::Result<()> {
721    let mut writer: Box<dyn Write> = match &config.output {
722        Some(path) => {
723            if let Some(parent) = path.parent() {
724                fs::create_dir_all(parent)?;
725            }
726            Box::new(BufWriter::with_capacity(256 * 1024, File::create(path)?))
727        }
728        None => Box::new(BufWriter::new(std::io::stdout())),
729    };
730
731    // Write header comment
732    write_header(&mut writer, config, stats)?;
733
734    // Write dialect-specific header
735    write_dialect_header(&mut writer, config.dialect)?;
736
737    // Write schema for each table (if enabled)
738    if config.include_schema {
739        for &table_id in table_order {
740            let runtime = match runtimes.get(&table_id) {
741                Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
742                _ => continue,
743            };
744
745            let table_file = tables_dir.join(format!("{}.sql", runtime.name));
746            if !table_file.exists() {
747                continue;
748            }
749
750            // Copy schema statements from table file
751            let file = File::open(&table_file)?;
752            let mut parser = Parser::with_dialect(file, 64 * 1024, config.dialect);
753
754            while let Some(stmt) = parser.read_statement()? {
755                let (stmt_type, _) =
756                    Parser::<&[u8]>::parse_statement_with_dialect(&stmt, config.dialect);
757
758                if stmt_type.is_schema() {
759                    writer.write_all(&stmt)?;
760                    writer.write_all(b"\n")?;
761                }
762            }
763        }
764    }
765
766    // Write data for each table
767    for &table_id in table_order {
768        let runtime = match runtimes.get(&table_id) {
769            Some(r) if !r.skip && !r.selected_rows.is_empty() => r,
770            _ => continue,
771        };
772
773        let table_name = &runtime.name;
774        let row_count = runtime.selected_rows.len();
775
776        writeln!(writer, "\n-- Data: {} ({} rows)", table_name, row_count)?;
777
778        // Write INSERTs in chunks (compact multi-row format)
779        const CHUNK_SIZE: usize = 1000;
780
781        // Get the table name quoting based on dialect
782        let quoted_name = match config.dialect {
783            SqlDialect::MySql => format!("`{}`", table_name),
784            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", table_name),
785        };
786
787        for chunk in runtime.selected_rows.chunks(CHUNK_SIZE) {
788            writeln!(writer, "INSERT INTO {} VALUES", quoted_name)?;
789
790            for (i, row) in chunk.iter().enumerate() {
791                if i > 0 {
792                    writer.write_all(b",\n")?;
793                }
794
795                // Convert row to INSERT VALUES format based on original format
796                let values = match row.format {
797                    RowFormat::Insert => {
798                        // Already in INSERT format, but may need dialect conversion
799                        match config.dialect {
800                            SqlDialect::Postgres => convert_row_to_postgres(&row.raw),
801                            _ => row.raw.clone(),
802                        }
803                    }
804                    RowFormat::Copy => {
805                        // Convert COPY format to INSERT VALUES
806                        convert_copy_to_insert_values(&row.raw, config.dialect)
807                    }
808                };
809                writer.write_all(&values)?;
810            }
811
812            writer.write_all(b";\n")?;
813        }
814    }
815
816    // Write dialect-specific footer
817    write_dialect_footer(&mut writer, config.dialect)?;
818
819    writer.flush()?;
820
821    Ok(())
822}
823
824/// Write header comment
825fn write_header<W: Write>(
826    writer: &mut W,
827    config: &SampleConfig,
828    stats: &SampleStats,
829) -> std::io::Result<()> {
830    writeln!(writer, "-- Sampled from: {}", config.input.display())?;
831    writeln!(
832        writer,
833        "-- Date: {}",
834        chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
835    )?;
836    writeln!(
837        writer,
838        "-- Mode: {:?}{}",
839        config.mode,
840        if config.preserve_relations {
841            ", preserve-relations"
842        } else {
843            ""
844        }
845    )?;
846    writeln!(writer, "-- Seed: {}", config.seed)?;
847    writeln!(writer, "-- Dialect: {}", config.dialect)?;
848    writeln!(writer, "--")?;
849    writeln!(writer, "-- Statistics:")?;
850    writeln!(writer, "--   Tables sampled: {}", stats.tables_sampled)?;
851    writeln!(writer, "--   Tables skipped: {}", stats.tables_skipped)?;
852
853    let percent = if stats.total_rows_seen > 0 {
854        (stats.total_rows_selected as f64 / stats.total_rows_seen as f64) * 100.0
855    } else {
856        0.0
857    };
858    writeln!(
859        writer,
860        "--   Total rows: {} (from {} original, {:.1}%)",
861        stats.total_rows_selected, stats.total_rows_seen, percent
862    )?;
863
864    if stats.fk_orphans_rejected > 0 {
865        writeln!(
866            writer,
867            "--   FK orphans rejected: {}",
868            stats.fk_orphans_rejected
869        )?;
870    }
871
872    if !stats.warnings.is_empty() {
873        writeln!(writer, "--   Warnings: {}", stats.warnings.len())?;
874    }
875
876    writeln!(writer)?;
877
878    Ok(())
879}
880
881/// Write dialect-specific header
882fn write_dialect_header<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
883    match dialect {
884        SqlDialect::MySql => {
885            writeln!(writer, "SET NAMES utf8mb4;")?;
886            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 0;")?;
887        }
888        SqlDialect::Postgres => {
889            writeln!(writer, "SET client_encoding = 'UTF8';")?;
890            writeln!(writer, "SET session_replication_role = replica;")?;
891        }
892        SqlDialect::Sqlite => {
893            writeln!(writer, "PRAGMA foreign_keys = OFF;")?;
894        }
895    }
896    writeln!(writer)?;
897    Ok(())
898}
899
900/// Write dialect-specific footer
901fn write_dialect_footer<W: Write>(writer: &mut W, dialect: SqlDialect) -> std::io::Result<()> {
902    writeln!(writer)?;
903    match dialect {
904        SqlDialect::MySql => {
905            writeln!(writer, "SET FOREIGN_KEY_CHECKS = 1;")?;
906        }
907        SqlDialect::Postgres => {
908            writeln!(writer, "SET session_replication_role = DEFAULT;")?;
909        }
910        SqlDialect::Sqlite => {
911            writeln!(writer, "PRAGMA foreign_keys = ON;")?;
912        }
913    }
914    Ok(())
915}
916
/// Rewrite MySQL-style backslash-escaped quotes (`\'`) as SQL-standard doubled
/// quotes (`''`).
///
/// Backslash escapes are consumed as *pairs*. In `'C:\\'` the `\\` is an escaped
/// backslash, so the `'` that follows is the real closing quote and must stay as
/// it is. Every pair other than `\'` (including `\\` and `\n`) is copied through
/// verbatim, as is a lone trailing backslash. A full MySQL→PostgreSQL escape
/// translation is out of scope.
fn convert_row_to_postgres(row: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(row.len());
    let mut i = 0;

    while i < row.len() {
        match (row[i], row.get(i + 1)) {
            // MySQL: \' -> PostgreSQL: ''
            (b'\\', Some(&b'\'')) => {
                result.extend_from_slice(b"''");
                i += 2;
            }
            // Any other escape pair is kept intact so its second byte is never
            // reinterpreted as the start of a new escape or as a quote.
            (b'\\', Some(&next)) => {
                result.push(b'\\');
                result.push(next);
                i += 2;
            }
            (byte, _) => {
                result.push(byte);
                i += 1;
            }
        }
    }

    result
}
938
939/// Convert PostgreSQL COPY format (tab-separated) to INSERT VALUES format
940fn convert_copy_to_insert_values(row: &[u8], dialect: SqlDialect) -> Vec<u8> {
941    let mut result = Vec::with_capacity(row.len() + 20);
942    result.push(b'(');
943
944    let fields: Vec<&[u8]> = row.split(|&b| b == b'\t').collect();
945
946    for (i, field) in fields.iter().enumerate() {
947        if i > 0 {
948            result.extend_from_slice(b", ");
949        }
950
951        // Check for NULL marker
952        if *field == b"\\N" {
953            result.extend_from_slice(b"NULL");
954        } else if field.is_empty() {
955            // Empty string
956            match dialect {
957                SqlDialect::MySql => result.extend_from_slice(b"''"),
958                SqlDialect::Postgres | SqlDialect::Sqlite => result.extend_from_slice(b"''"),
959            }
960        } else if is_numeric(field) {
961            // Numeric value - no quotes needed
962            result.extend_from_slice(field);
963        } else {
964            // String value - needs quoting
965            result.push(b'\'');
966            for &b in *field {
967                match b {
968                    b'\'' => {
969                        // Escape single quote
970                        match dialect {
971                            SqlDialect::MySql => result.extend_from_slice(b"\\'"),
972                            SqlDialect::Postgres | SqlDialect::Sqlite => {
973                                result.extend_from_slice(b"''")
974                            }
975                        }
976                    }
977                    b'\\' if dialect == SqlDialect::MySql => {
978                        // Escape backslash in MySQL
979                        result.extend_from_slice(b"\\\\");
980                    }
981                    _ => result.push(b),
982                }
983            }
984            result.push(b'\'');
985        }
986    }
987
988    result.push(b')');
989    result
990}
991
/// Whether a COPY field can be emitted as an unquoted numeric literal without
/// changing what ends up in the database.
///
/// Only plain decimals are accepted: an optional leading `-`, an integer part
/// without redundant leading zeros, and an optional `.` followed by at least one
/// digit (e.g. `0`, `42`, `-17`, `3.14`). Everything else is rejected, including
/// `+5`, `007`, `1e5`, `.5`, `5.`, `e5` and the empty string, so the caller quotes it.
/// Quoting is always safe, because a quoted literal is accepted for numeric columns.
/// An unquoted literal aimed at a text column, by contrast, would lose formatting
/// such as leading zeros (`00123` → `123`) or a leading `+` on phone numbers.
fn is_numeric(s: &[u8]) -> bool {
    let unsigned = match s {
        [b'-', rest @ ..] => rest,
        _ => s,
    };

    let (int_part, frac_part) = match unsigned.iter().position(|&b| b == b'.') {
        Some(dot) => (&unsigned[..dot], Some(&unsigned[dot + 1..])),
        None => (unsigned, None),
    };

    let is_digits = |part: &[u8]| !part.is_empty() && part.iter().all(u8::is_ascii_digit);

    // "0" is fine, "007" is not
    let int_ok = is_digits(int_part) && (int_part.len() == 1 || int_part[0] != b'0');
    let frac_ok = match frac_part {
        Some(frac) => is_digits(frac),
        None => true,
    };

    int_ok && frac_ok
}