use crate::parser::{
determine_buffer_size, mysql_insert, postgres_copy, Parser, SqlDialect, StatementType,
};
use crate::progress::ProgressReader;
use crate::schema::{Schema, SchemaBuilder, TableId};
use crate::splitter::Compression;
use ahash::{AHashMap, AHashSet};
use schemars::JsonSchema;
use serde::Serialize;
use std::fmt;
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::path::PathBuf;
use std::sync::Arc;
/// Hard cap on the number of issues retained/reported; issues beyond this
/// are dropped by `Validator::add_issue` to bound memory on huge dumps.
const MAX_ISSUES: usize = 1000;
/// Severity level attached to each [`ValidationIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Severity {
/// A definite problem (e.g. parse error, duplicate primary key).
Error,
/// A potential problem (e.g. invalid UTF-8, skipped checks).
Warning,
/// Informational note.
Info,
}
impl fmt::Display for Severity {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Severity::Error => write!(f, "ERROR"),
Severity::Warning => write!(f, "WARNING"),
Severity::Info => write!(f, "INFO"),
}
}
}
/// Optional context pinpointing where a [`ValidationIssue`] occurred.
#[derive(Debug, Clone, Serialize, JsonSchema)]
pub struct Location {
/// Table the issue relates to, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub table: Option<String>,
/// 1-based index of the statement within the dump, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub statement_index: Option<u64>,
/// Approximate line number in the source file, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub approx_line: Option<u64>,
}
impl Location {
    /// An empty location with no context set.
    pub fn new() -> Self {
        Self {
            table: None,
            statement_index: None,
            approx_line: None,
        }
    }
    /// Returns a copy with the table name attached.
    pub fn with_table(self, table: impl Into<String>) -> Self {
        Self {
            table: Some(table.into()),
            ..self
        }
    }
    /// Returns a copy with the statement index attached.
    pub fn with_statement(self, index: u64) -> Self {
        Self {
            statement_index: Some(index),
            ..self
        }
    }
    /// Returns a copy with the approximate line number attached.
    #[allow(dead_code)]
    pub fn with_line(self, line: u64) -> Self {
        Self {
            approx_line: Some(line),
            ..self
        }
    }
}
impl Default for Location {
fn default() -> Self {
Self::new()
}
}
/// A single problem found while validating a dump.
#[derive(Debug, Clone, Serialize, JsonSchema)]
pub struct ValidationIssue {
/// Stable machine-readable code, e.g. "SYNTAX" or "DUPLICATE_PK".
pub code: &'static str,
/// How serious the issue is.
pub severity: Severity,
/// Human-readable description.
pub message: String,
/// Where the issue was found, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub location: Option<Location>,
}
impl ValidationIssue {
pub fn error(code: &'static str, message: impl Into<String>) -> Self {
Self {
code,
severity: Severity::Error,
message: message.into(),
location: None,
}
}
pub fn warning(code: &'static str, message: impl Into<String>) -> Self {
Self {
code,
severity: Severity::Warning,
message: message.into(),
location: None,
}
}
pub fn info(code: &'static str, message: impl Into<String>) -> Self {
Self {
code,
severity: Severity::Info,
message: message.into(),
location: None,
}
}
pub fn with_location(mut self, location: Location) -> Self {
self.location = Some(location);
self
}
}
impl fmt::Display for ValidationIssue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} [{}]", self.severity, self.code)?;
if let Some(ref loc) = self.location {
if let Some(ref table) = loc.table {
write!(f, " table={}", table)?;
}
if let Some(stmt) = loc.statement_index {
write!(f, " stmt={}", stmt)?;
}
if let Some(line) = loc.approx_line {
write!(f, " line~{}", line)?;
}
}
write!(f, ": {}", self.message)
}
}
/// Configuration for a validation run.
#[derive(Debug, Clone)]
pub struct ValidateOptions {
/// Path to the dump file (possibly compressed; see `Compression::from_path`).
pub path: PathBuf,
/// SQL dialect; defaults to MySQL when `None`.
pub dialect: Option<SqlDialect>,
/// Whether to report progress.
/// NOTE(review): not read in this file — presumably consumed by the caller
/// that installs the progress callback; confirm.
pub progress: bool,
/// Strict mode. NOTE(review): not read in this file; confirm caller usage.
pub strict: bool,
/// Emit JSON output. NOTE(review): not read in this file; confirm caller usage.
pub json: bool,
/// Per-table row cap; beyond it PK/FK checks for that table are skipped.
pub max_rows_per_table: usize,
/// Enables the schema-based second pass (PK/FK checks).
pub fk_checks_enabled: bool,
/// Optional cap on total tracked PK+FK keys; once exceeded, PK/FK checks
/// are disabled entirely to bound memory.
pub max_pk_fk_keys: Option<usize>,
}
/// Full validation report: issue list plus aggregate stats and per-check
/// statuses. Produced by [`Validator::validate`].
#[derive(Debug, Serialize, JsonSchema)]
pub struct ValidationSummary {
/// Dialect the dump was parsed as.
pub dialect: String,
/// Individual issues (capped at `MAX_ISSUES`).
pub issues: Vec<ValidationIssue>,
/// Aggregate counters.
pub summary: SummaryStats,
/// Outcome of each check category.
pub checks: CheckResults,
}
/// Aggregate counters for a validation run.
#[derive(Debug, Serialize, JsonSchema)]
pub struct SummaryStats {
/// Error-severity issues retained (the list is capped at `MAX_ISSUES`).
pub errors: usize,
/// Warning-severity issues retained.
pub warnings: usize,
/// Info-severity issues retained.
pub info: usize,
/// Distinct tables seen in CREATE TABLE statements.
pub tables_scanned: usize,
/// Total statements parsed in the first pass.
pub statements_scanned: u64,
}
/// Per-category outcome of the validation checks.
#[derive(Debug, Serialize, JsonSchema)]
pub struct CheckResults {
/// Statement-level parse errors.
pub syntax: CheckStatus,
/// Invalid UTF-8 detection.
pub encoding: CheckStatus,
/// INSERT/COPY statements referencing tables with no CREATE TABLE.
pub ddl_dml_consistency: CheckStatus,
/// Duplicate primary key detection.
pub pk_duplicates: CheckStatus,
/// Foreign-key integrity (rows referencing missing parent rows).
pub fk_integrity: CheckStatus,
}
/// Outcome of one validation check category.
#[derive(Debug, Serialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum CheckStatus {
/// The check ran and found nothing.
Ok,
/// The check found this many issues.
Failed(usize),
/// The check did not run; the payload explains why.
Skipped(String),
}
impl fmt::Display for CheckStatus {
    /// Human-readable status, e.g. "OK", "3 issues", "Skipped (reason)".
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CheckStatus::Ok => f.write_str("OK"),
            CheckStatus::Failed(count) => write!(f, "{count} issues"),
            CheckStatus::Skipped(reason) => write!(f, "Skipped ({reason})"),
        }
    }
}
impl ValidationSummary {
    /// True when at least one error-severity issue was recorded.
    pub fn has_errors(&self) -> bool {
        self.summary.errors != 0
    }
    /// True when at least one warning-severity issue was recorded.
    pub fn has_warnings(&self) -> bool {
        self.summary.warnings != 0
    }
}
/// 64-bit hash of a primary-key (or foreign-key) value tuple.
/// NOTE(review): duplicate/FK checks compare hashes rather than values, so a
/// 64-bit collision could in principle mask a real duplicate or FK miss.
type PkHash = u64;
/// Hashes a tuple of key values into a single 64-bit digest.
///
/// The value count and a per-variant tag byte are mixed in first so that,
/// e.g., `Int(1)` and `BigInt(1)`, or tuples of different arity, do not
/// collide trivially.
fn hash_pk_values(values: &smallvec::SmallVec<[mysql_insert::PkValue; 2]>) -> PkHash {
    let mut hasher = ahash::AHasher::default();
    (values.len() as u8).hash(&mut hasher);
    for value in values {
        // Tuple hashing feeds the tag byte, then the payload, in order —
        // the same byte stream as hashing each part separately.
        match value {
            mysql_insert::PkValue::Int(i) => (0u8, i).hash(&mut hasher),
            mysql_insert::PkValue::BigInt(i) => (1u8, i).hash(&mut hasher),
            mysql_insert::PkValue::Text(s) => (2u8, s).hash(&mut hasher),
            mysql_insert::PkValue::Null => 3u8.hash(&mut hasher),
        }
    }
    hasher.finish()
}
/// A deferred foreign-key lookup.
///
/// FK values are queued during the data scan and resolved afterwards
/// (`validate_pending_fk_checks`), so parent rows appearing later in the
/// dump than their children are still found.
struct PendingFkCheck {
/// Table containing the foreign key.
child_table_id: TableId,
/// Table the foreign key references.
parent_table_id: TableId,
/// Hash of the FK value tuple, matched against the parent's PK hashes.
fk_hash: PkHash,
/// Statement index where the row appeared (used for issue locations).
stmt_idx: u64,
}
/// Per-table bookkeeping for the data-check pass.
struct TableState {
/// Rows seen so far for this table.
row_count: u64,
/// Hashes of PK tuples seen; `None` once tracking is dropped because the
/// row limit or the global memory budget was exceeded.
pk_values: Option<AHashSet<PkHash>>,
/// Column positions forming the primary key.
/// NOTE(review): populated in `initialize_table_states` but not read in
/// this file — presumably the row parsers take PK positions from the
/// schema instead; confirm.
pk_column_indices: Vec<usize>,
/// Duplicate primary keys detected for this table.
pk_duplicates: u64,
/// FK references from this table whose parent row was missing.
fk_missing_parents: u64,
}
impl TableState {
    /// Fresh state: zero rows, PK tracking enabled, PK columns unknown.
    fn new() -> Self {
        Self {
            pk_values: Some(AHashSet::new()),
            pk_column_indices: Vec::new(),
            row_count: 0,
            pk_duplicates: 0,
            fk_missing_parents: 0,
        }
    }
    /// Returns a copy with the PK column positions set.
    fn with_pk_columns(self, indices: Vec<usize>) -> Self {
        Self {
            pk_column_indices: indices,
            ..self
        }
    }
}
/// Streaming validator for SQL dump files.
///
/// Pass one parses every statement, checking syntax/encoding, collecting
/// table names from DDL and DML, and building a schema. When FK checks are
/// enabled and the schema is non-empty, a second pass (`run_data_checks`)
/// re-reads the file and checks PK uniqueness and FK integrity per row.
pub struct Validator {
/// Run configuration.
options: ValidateOptions,
/// Issues found so far (capped at `MAX_ISSUES`).
issues: Vec<ValidationIssue>,
/// Effective dialect (defaults to MySQL when options leave it unset).
dialect: SqlDialect,
/// Table names seen in CREATE TABLE statements.
tables_from_ddl: AHashSet<String>,
/// (table name, statement index) for each INSERT/COPY seen in pass one.
tables_from_dml: Vec<(String, u64)>,
/// Accumulates schema information during pass one.
schema_builder: SchemaBuilder,
/// Built schema; populated before pass two when FK checks are enabled.
schema: Option<Schema>,
/// Per-table PK/FK bookkeeping for pass two.
table_states: AHashMap<TableId, TableState>,
/// FK lookups deferred until all rows have been scanned.
pending_fk_checks: Vec<PendingFkCheck>,
/// Optional progress callback; receives an estimated byte count (each of
/// the two passes contributes up to half of the file size).
progress_fn: Option<Arc<dyn Fn(u64) + Send + Sync>>,
/// Statements parsed in pass one.
statement_count: u64,
// Per-check counters feeding the final `CheckResults`.
syntax_errors: usize,
encoding_warnings: usize,
ddl_dml_errors: usize,
pk_errors: usize,
fk_errors: usize,
/// Distinct PK hashes currently tracked (memory-budget accounting).
tracked_pk_count: usize,
/// FK checks currently queued (memory-budget accounting).
tracked_fk_count: usize,
/// Set once `max_pk_fk_keys` is exceeded; disables all PK/FK checks.
pk_fk_checks_disabled_due_to_memory: bool,
/// Active Postgres COPY block: (table name, column order, table id).
current_copy_context: Option<(String, Vec<String>, TableId)>,
}
impl Validator {
/// Creates a validator for the given options.
///
/// The dialect defaults to MySQL when the options leave it unset.
pub fn new(options: ValidateOptions) -> Self {
    let dialect = options.dialect.unwrap_or(SqlDialect::MySql);
    Self {
        options,
        dialect,
        issues: Vec::new(),
        tables_from_ddl: AHashSet::new(),
        tables_from_dml: Vec::new(),
        schema_builder: SchemaBuilder::new(),
        schema: None,
        table_states: AHashMap::new(),
        pending_fk_checks: Vec::new(),
        progress_fn: None,
        current_copy_context: None,
        statement_count: 0,
        syntax_errors: 0,
        encoding_warnings: 0,
        ddl_dml_errors: 0,
        pk_errors: 0,
        fk_errors: 0,
        tracked_pk_count: 0,
        tracked_fk_count: 0,
        pk_fk_checks_disabled_due_to_memory: false,
    }
}
/// Registers a progress callback. It receives an estimated total byte
/// count: pass one reports `bytes / 2`, pass two `file_size / 2 + bytes / 2`,
/// so each pass contributes roughly half of the file size.
pub fn with_progress<F>(mut self, f: F) -> Self
where
F: Fn(u64) + Send + Sync + 'static,
{
self.progress_fn = Some(Arc::new(f));
self
}
/// Records a validation issue and updates the per-check counters.
///
/// Fix: the counters are now updated for every issue, even after the issue
/// list reaches `MAX_ISSUES`. Previously the early return skipped counting
/// too, so `CheckStatus::Failed(n)` understated failures on dumps with more
/// than `MAX_ISSUES` problems. Only the stored list is capped.
fn add_issue(&mut self, issue: ValidationIssue) {
    match issue.severity {
        Severity::Error => match issue.code {
            "SYNTAX" => self.syntax_errors += 1,
            "DDL_MISSING_TABLE" => self.ddl_dml_errors += 1,
            "DUPLICATE_PK" => self.pk_errors += 1,
            "FK_MISSING_PARENT" => self.fk_errors += 1,
            _ => {}
        },
        Severity::Warning => {
            if issue.code == "ENCODING" {
                self.encoding_warnings += 1;
            }
        }
        Severity::Info => {}
    }
    // Bound memory: retain at most MAX_ISSUES issues.
    if self.issues.len() < MAX_ISSUES {
        self.issues.push(issue);
    }
}
/// Disables PK/FK tracking once the total number of tracked keys exceeds
/// the configured `max_pk_fk_keys` budget, releasing all tracked state.
fn enforce_pk_fk_memory_budget(&mut self) {
    if self.pk_fk_checks_disabled_due_to_memory {
        return;
    }
    let Some(limit) = self.options.max_pk_fk_keys else {
        return; // No budget configured.
    };
    let total_tracked = self.tracked_pk_count + self.tracked_fk_count;
    if total_tracked <= limit {
        return;
    }
    self.pk_fk_checks_disabled_due_to_memory = true;
    // Free every tracked PK set and all queued FK checks right away.
    for state in self.table_states.values_mut() {
        state.pk_values = None;
    }
    self.pending_fk_checks = Vec::new();
    self.add_issue(ValidationIssue::warning(
        "PK_FK_CHECKS_SKIPPED_MEMORY",
        format!(
            "Skipping PK/FK checks after tracking {} keys (memory limit of {} exceeded)",
            total_tracked, limit
        ),
    ));
}
/// Runs the full validation and returns the summary report.
///
/// Pass one streams every statement through `process_statement`; the DML
/// table names are then cross-checked against the DDL; finally, when FK
/// checks are enabled and yielded a non-empty schema, the data pass and
/// deferred FK resolution run.
///
/// Fix: the DDL/DML cross-check previously re-lowercased the entire DDL
/// table set for every DML reference (O(ddl * dml) allocations); the
/// lowered set is now built once.
pub fn validate(mut self) -> anyhow::Result<ValidationSummary> {
    let file = File::open(&self.options.path)?;
    let file_size = file.metadata()?.len();
    let buffer_size = determine_buffer_size(file_size);
    let compression = Compression::from_path(&self.options.path);
    // Pass one reports 0-50% of progress; the optional data pass covers the rest.
    let reader: Box<dyn Read> = if let Some(ref cb) = self.progress_fn {
        let cb = Arc::clone(cb);
        let progress_reader = ProgressReader::new(file, move |bytes| cb(bytes / 2));
        compression.wrap_reader(Box::new(progress_reader))?
    } else {
        compression.wrap_reader(Box::new(file))?
    };
    let mut parser = Parser::with_dialect(reader, buffer_size, self.dialect);
    loop {
        match parser.read_statement() {
            Ok(Some(stmt)) => {
                self.statement_count += 1;
                self.process_statement(&stmt);
            }
            Ok(None) => break,
            Err(e) => {
                // A parser error ends the scan; report it against the
                // statement that failed to parse.
                self.add_issue(
                    ValidationIssue::error("SYNTAX", format!("Parser error: {}", e))
                        .with_location(
                            Location::new().with_statement(self.statement_count + 1),
                        ),
                );
                break;
            }
        }
    }
    // Cross-check DML against DDL, case-insensitively. Lowercase the DDL
    // names once instead of per DML reference.
    let ddl_tables_lower: std::collections::HashSet<String> = self
        .tables_from_ddl
        .iter()
        .map(|t| t.to_lowercase())
        .collect();
    let missing_table_issues: Vec<_> = self
        .tables_from_dml
        .iter()
        .filter(|(table, _)| !ddl_tables_lower.contains(&table.to_lowercase()))
        .map(|(table, stmt_idx)| {
            ValidationIssue::error(
                "DDL_MISSING_TABLE",
                format!(
                    "INSERT/COPY references table '{}' with no CREATE TABLE",
                    table
                ),
            )
            .with_location(Location::new().with_table(table).with_statement(*stmt_idx))
        })
        .collect();
    for issue in missing_table_issues {
        self.add_issue(issue);
    }
    if self.options.fk_checks_enabled {
        self.schema = Some(self.schema_builder.build());
        // The builder is no longer needed; replace it to release its memory.
        self.schema_builder = SchemaBuilder::new();
        self.initialize_table_states();
    }
    let schema_not_empty = self.schema.as_ref().is_some_and(|s| !s.is_empty());
    if self.options.fk_checks_enabled && schema_not_empty {
        self.run_data_checks()?;
        self.validate_pending_fk_checks();
    }
    Ok(self.build_summary())
}
/// Classifies one statement from pass one: records encoding problems,
/// DDL-declared tables (feeding the schema builder) and DML table
/// references.
fn process_statement(&mut self, stmt: &[u8]) {
    // Flag invalid UTF-8 but keep going: byte-level classification still works.
    if std::str::from_utf8(stmt).is_err() {
        self.add_issue(
            ValidationIssue::warning("ENCODING", "Statement contains invalid UTF-8 bytes")
                .with_location(Location::new().with_statement(self.statement_count)),
        );
    }
    let (stmt_type, table_name) =
        Parser::<&[u8]>::parse_statement_with_dialect(stmt, self.dialect);
    match stmt_type {
        StatementType::CreateTable if !table_name.is_empty() => {
            self.tables_from_ddl.insert(table_name);
            if let Ok(stmt_str) = std::str::from_utf8(stmt) {
                self.schema_builder.parse_create_table(stmt_str);
            }
        }
        StatementType::AlterTable => {
            if let Ok(stmt_str) = std::str::from_utf8(stmt) {
                self.schema_builder.parse_alter_table(stmt_str);
            }
        }
        StatementType::Insert | StatementType::Copy if !table_name.is_empty() => {
            self.tables_from_dml
                .push((table_name, self.statement_count));
        }
        // Unknown statements (and DDL/DML without a table name) are ignored.
        _ => {}
    }
}
/// Creates one `TableState` per schema table, recording which column
/// positions form the primary key. No-op when no schema was built.
fn initialize_table_states(&mut self) {
    let Some(schema) = &self.schema else {
        return;
    };
    for table_schema in schema.iter() {
        let pk_indices: Vec<usize> = table_schema
            .primary_key
            .iter()
            .map(|col_id| col_id.0 as usize)
            .collect();
        self.table_states
            .insert(table_schema.id, TableState::new().with_pk_columns(pk_indices));
    }
}
/// Second pass over the dump: re-reads every statement and routes INSERT
/// and COPY row data into the per-row PK/FK checks.
///
/// Postgres COPY is handled as a small state machine: the COPY header
/// statement records the target table and column order in
/// `current_copy_context`, and the following "Unknown" statement carrying
/// the data block (terminated by `\.`) is checked against that context.
fn run_data_checks(&mut self) -> anyhow::Result<()> {
let file = File::open(&self.options.path)?;
let file_size = file.metadata()?.len();
let buffer_size = determine_buffer_size(file_size);
let compression = Compression::from_path(&self.options.path);
// Progress for this pass maps onto the 50-100% range (pass one reported
// 0-50% of the file size).
let reader: Box<dyn Read> = if let Some(ref cb) = self.progress_fn {
let cb = Arc::clone(cb);
let progress_reader = ProgressReader::new(file, move |bytes| {
cb(file_size / 2 + bytes / 2)
});
compression.wrap_reader(Box::new(progress_reader))?
} else {
compression.wrap_reader(Box::new(file))?
};
let mut parser = Parser::with_dialect(reader, buffer_size, self.dialect);
let mut stmt_count: u64 = 0;
self.current_copy_context = None;
while let Some(stmt) = parser.read_statement()? {
stmt_count += 1;
let (stmt_type, table_name) =
Parser::<&[u8]>::parse_statement_with_dialect(&stmt, self.dialect);
// Postgres COPY data arrives as an "Unknown" statement after the COPY
// header; a block ending in "\." terminates the data section.
if self.dialect == SqlDialect::Postgres && stmt_type == StatementType::Unknown {
if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
if let Some((ref copy_table, ref column_order, copy_table_id)) =
self.current_copy_context.clone()
{
self.check_copy_data(
&stmt,
copy_table_id,
copy_table,
column_order.clone(),
stmt_count,
);
}
}
// NOTE(review): the context is cleared on every Unknown statement,
// so COPY data not ending in "\." is dropped — presumably the parser
// yields the whole data block as one statement; confirm.
self.current_copy_context = None;
continue;
}
// Tables absent from the schema are skipped here; missing CREATE TABLE
// was already reported by pass one.
let table_id = match &self.schema {
Some(s) => match s.get_table_id(&table_name) {
Some(id) => id,
None => continue,
},
None => continue,
};
match stmt_type {
StatementType::Insert => {
self.check_insert_statement(&stmt, table_id, &table_name, stmt_count);
}
StatementType::Copy => {
// Remember the COPY target and column order; the data block
// follows in a subsequent statement.
let header = String::from_utf8_lossy(&stmt);
let column_order = postgres_copy::parse_copy_columns(&header);
self.current_copy_context = Some((table_name.clone(), column_order, table_id));
}
_ => continue,
}
}
Ok(())
}
/// Parses a MySQL-style INSERT and runs PK/FK checks on every row in it.
///
/// Statements that cannot be parsed are ignored here; syntax problems are
/// reported by the first pass.
fn check_insert_statement(
    &mut self,
    stmt: &[u8],
    table_id: TableId,
    table_name: &str,
    stmt_count: u64,
) {
    // Without the schema entry we cannot locate PK/FK columns.
    let Some(table_schema) = self.schema.as_ref().and_then(|s| s.table(table_id)) else {
        return;
    };
    let Ok(rows) = mysql_insert::parse_mysql_insert_rows(stmt, table_schema) else {
        return;
    };
    for row in rows {
        self.check_mysql_row(table_id, table_name, &row, stmt_count);
    }
}
/// Checks a single-statement COPY (header and inline data in one buffer).
///
/// Currently unused: `run_data_checks` handles COPY via
/// `current_copy_context` + `check_copy_data` instead.
#[allow(dead_code)]
fn check_copy_statement(
    &mut self,
    stmt: &[u8],
    table_id: TableId,
    table_name: &str,
    stmt_count: u64,
) {
    let Ok(stmt_str) = std::str::from_utf8(stmt) else {
        return;
    };
    // Find the inline data section that follows the COPY ... FROM stdin
    // header (upper- or lower-case marker; both are 11 bytes long).
    let Some(pos) = stmt_str
        .find("FROM stdin;")
        .or_else(|| stmt_str.find("from stdin;"))
    else {
        return;
    };
    let data_start = pos + "FROM stdin;".len();
    let data_section = stmt_str[data_start..].trim_start();
    if data_section.is_empty() {
        return;
    }
    let column_order = postgres_copy::parse_copy_columns(&stmt_str[..data_start]);
    let Some(table_schema) = self.schema.as_ref().and_then(|s| s.table(table_id)) else {
        return;
    };
    let Ok(rows) = postgres_copy::parse_postgres_copy_rows(
        data_section.as_bytes(),
        table_schema,
        column_order,
    ) else {
        return;
    };
    for row in rows {
        self.check_copy_row(table_id, table_name, &row, stmt_count);
    }
}
/// Parses a Postgres COPY data block and runs PK/FK checks on every row.
///
/// Fix: the old implementation copied the entire (potentially very large)
/// data block into a fresh `Vec<u8>` just to strip leading whitespace; it
/// now slices the input in place.
fn check_copy_data(
    &mut self,
    data_stmt: &[u8],
    table_id: TableId,
    table_name: &str,
    column_order: Vec<String>,
    stmt_count: u64,
) {
    // Locate the first non-whitespace byte; all-whitespace (or empty)
    // input means there is nothing to validate.
    let Some(start) = data_stmt
        .iter()
        .position(|&b| !matches!(b, b'\n' | b'\r' | b' ' | b'\t'))
    else {
        return;
    };
    let data = &data_stmt[start..];
    // Without the schema entry we cannot locate PK/FK columns.
    let Some(table_schema) = self.schema.as_ref().and_then(|s| s.table(table_id)) else {
        return;
    };
    let Ok(rows) = postgres_copy::parse_postgres_copy_rows(data, table_schema, column_order)
    else {
        return;
    };
    for row in rows {
        self.check_copy_row(table_id, table_name, &row, stmt_count);
    }
}
/// Runs the shared PK/FK checks for one row parsed from a MySQL INSERT.
fn check_mysql_row(
&mut self,
table_id: TableId,
table_name: &str,
row: &mysql_insert::ParsedRow,
stmt_idx: u64,
) {
self.check_row_common(
table_id,
table_name,
row.pk.as_ref(),
&row.fk_values,
stmt_idx,
);
}
/// Runs the shared PK/FK checks for one row parsed from a Postgres COPY block.
fn check_copy_row(
&mut self,
table_id: TableId,
table_name: &str,
row: &postgres_copy::ParsedCopyRow,
stmt_idx: u64,
) {
self.check_row_common(
table_id,
table_name,
row.pk.as_ref(),
&row.fk_values,
stmt_idx,
);
}
/// Shared per-row PK/FK bookkeeping used by both the MySQL INSERT and the
/// Postgres COPY paths.
///
/// Increments the table's row count, records the PK hash to detect
/// duplicates, and queues FK value hashes for post-scan resolution in
/// `validate_pending_fk_checks`. No-op once the memory budget has disabled
/// PK/FK checks.
fn check_row_common(
&mut self,
table_id: TableId,
table_name: &str,
pk: Option<&smallvec::SmallVec<[mysql_insert::PkValue; 2]>>,
fk_values: &[(
mysql_insert::FkRef,
smallvec::SmallVec<[mysql_insert::PkValue; 2]>,
)],
stmt_idx: u64,
) {
if self.pk_fk_checks_disabled_due_to_memory {
return;
}
let max_rows = self.options.max_rows_per_table as u64;
let state = match self.table_states.get_mut(&table_id) {
Some(s) => s,
None => return,
};
state.row_count += 1;
// Beyond the per-table row budget: drop the PK set (warning emitted only
// the first time this happens) and skip all further checks for this table.
if state.row_count > max_rows {
if state.pk_values.is_some() {
state.pk_values = None;
self.add_issue(
ValidationIssue::warning(
"PK_CHECK_SKIPPED",
format!(
"Skipping PK/FK checks for table '{}' after {} rows (increase --max-rows-per-table)",
table_name, max_rows
),
)
.with_location(Location::new().with_table(table_name)),
);
}
return;
}
if let Some(pk_values) = pk {
if let Some(ref mut pk_set) = state.pk_values {
// Uniqueness is tracked by 64-bit hash; a collision could in
// principle mask a genuine duplicate.
let pk_hash = hash_pk_values(pk_values);
if pk_set.insert(pk_hash) {
self.tracked_pk_count += 1;
self.enforce_pk_fk_memory_budget();
} else {
state.pk_duplicates += 1;
// Render the duplicate key tuple for the error message.
let pk_display: String = pk_values
.iter()
.map(|v| match v {
mysql_insert::PkValue::Int(i) => i.to_string(),
mysql_insert::PkValue::BigInt(i) => i.to_string(),
mysql_insert::PkValue::Text(s) => s.to_string(),
mysql_insert::PkValue::Null => "NULL".to_string(),
})
.collect::<Vec<_>>()
.join(", ");
self.add_issue(
ValidationIssue::error(
"DUPLICATE_PK",
format!(
"Duplicate primary key in table '{}': ({})",
table_name, pk_display
),
)
.with_location(
Location::new()
.with_table(table_name)
.with_statement(stmt_idx),
),
);
}
}
}
// The PK insert above may have just tripped the memory budget.
if self.pk_fk_checks_disabled_due_to_memory {
return;
}
// Queue FK hashes for post-scan resolution: parent rows may appear later
// in the dump than their children.
let new_fk_checks: Vec<PendingFkCheck> = {
let schema = match &self.schema {
Some(s) => s,
None => return,
};
let table_schema = match schema.table(table_id) {
Some(t) => t,
None => return,
};
fk_values
.iter()
.filter_map(|(fk_ref, fk_vals)| {
// An all-NULL FK is valid per SQL semantics; skip it.
if fk_vals.iter().all(|v| v.is_null()) {
return None;
}
let fk_def = table_schema.foreign_keys.get(fk_ref.fk_index as usize)?;
let parent_table_id = fk_def.referenced_table_id?;
let fk_hash = hash_pk_values(fk_vals);
Some(PendingFkCheck {
child_table_id: table_id,
parent_table_id,
fk_hash,
stmt_idx,
})
})
.collect()
};
let new_count = new_fk_checks.len();
self.pending_fk_checks.extend(new_fk_checks);
self.tracked_fk_count += new_count;
if new_count > 0 {
self.enforce_pk_fk_memory_budget();
}
}
/// Resolves all queued FK checks against the PK hash sets collected during
/// the data pass, reporting at most 5 violations per child table (the
/// per-table counter still tracks the full count).
///
/// Fix: when a parent table's PK set was dropped because it exceeded
/// `max_rows_per_table` (`pk_values == None`), its contents are unknown, so
/// the check is now skipped instead of reported. The previous
/// implementation treated the missing set as "PK absent" and emitted a
/// false FK_MISSING_PARENT for every reference into such a table.
fn validate_pending_fk_checks(&mut self) {
    for check in std::mem::take(&mut self.pending_fk_checks) {
        match self
            .table_states
            .get(&check.parent_table_id)
            .map(|s| s.pk_values.as_ref())
        {
            // Parent PK present: FK satisfied.
            Some(Some(pk_set)) if pk_set.contains(&check.fk_hash) => continue,
            // Parent PK tracking was dropped (row limit exceeded): result
            // unknown, do not report a violation.
            Some(None) => continue,
            // Parent row genuinely missing (or parent table untracked).
            _ => {}
        }
        let state = match self.table_states.get_mut(&check.child_table_id) {
            Some(s) => s,
            None => continue,
        };
        state.fk_missing_parents += 1;
        if state.fk_missing_parents <= 5 {
            let (child_name, parent_name) = if let Some(schema) = &self.schema {
                let child = schema
                    .table(check.child_table_id)
                    .map(|t| t.name.clone())
                    .unwrap_or_else(|| "<unknown>".to_string());
                let parent = schema
                    .table(check.parent_table_id)
                    .map(|t| t.name.clone())
                    .unwrap_or_else(|| "<unknown>".to_string());
                (child, parent)
            } else {
                ("<unknown>".to_string(), "<unknown>".to_string())
            };
            self.add_issue(
                ValidationIssue::error(
                    "FK_MISSING_PARENT",
                    format!(
                        "FK violation in '{}': references missing row in '{}'",
                        child_name, parent_name
                    ),
                )
                .with_location(
                    Location::new()
                        .with_table(child_name)
                        .with_statement(check.stmt_idx),
                ),
            );
        }
    }
}
fn build_summary(&self) -> ValidationSummary {
let errors = self
.issues
.iter()
.filter(|i| matches!(i.severity, Severity::Error))
.count();
let warnings = self
.issues
.iter()
.filter(|i| matches!(i.severity, Severity::Warning))
.count();
let info = self
.issues
.iter()
.filter(|i| matches!(i.severity, Severity::Info))
.count();
let syntax_status = if self.syntax_errors > 0 {
CheckStatus::Failed(self.syntax_errors)
} else {
CheckStatus::Ok
};
let encoding_status = if self.encoding_warnings > 0 {
CheckStatus::Failed(self.encoding_warnings)
} else {
CheckStatus::Ok
};
let ddl_dml_status = if self.ddl_dml_errors > 0 {
CheckStatus::Failed(self.ddl_dml_errors)
} else {
CheckStatus::Ok
};
let pk_status = if !self.options.fk_checks_enabled {
CheckStatus::Skipped("--no-fk-checks".to_string())
} else if self.pk_fk_checks_disabled_due_to_memory {
CheckStatus::Skipped("memory limit exceeded".to_string())
} else if self.pk_errors > 0 {
CheckStatus::Failed(self.pk_errors)
} else {
CheckStatus::Ok
};
let fk_status = if !self.options.fk_checks_enabled {
CheckStatus::Skipped("--no-fk-checks".to_string())
} else if self.pk_fk_checks_disabled_due_to_memory {
CheckStatus::Skipped("memory limit exceeded".to_string())
} else if self.fk_errors > 0 {
CheckStatus::Failed(self.fk_errors)
} else {
CheckStatus::Ok
};
ValidationSummary {
dialect: self.dialect.to_string(),
issues: self.issues.clone(),
summary: SummaryStats {
errors,
warnings,
info,
tables_scanned: self.tables_from_ddl.len(),
statements_scanned: self.statement_count,
},
checks: CheckResults {
syntax: syntax_status,
encoding: encoding_status,
ddl_dml_consistency: ddl_dml_status,
pk_duplicates: pk_status,
fk_integrity: fk_status,
},
}
}
}