use super::*;
/// Parses a compatibility-mode column expression.
///
/// The accepted shape is one or more identifier segments separated by `.`,
/// optionally followed by one or more `::type` casts. Surrounding whitespace
/// on `input` is ignored. The segments are stored joined with `.` in `path`,
/// and the cast reported by `parse_optional_trailing_cast` is stored in `cast`.
///
/// # Errors
///
/// * `"missing column expression"` when `input` is empty after trimming.
/// * Any error produced while parsing a segment or a cast.
/// * `"unsupported column expression in compatibility mode"` when input
///   remains after the path and optional cast have been consumed.
pub(super) fn parse_column_expr(input: &str) -> Result<CompatibilityColumnExpr, String> {
    let expr = input.trim();
    if expr.is_empty() {
        return Err("missing column expression".to_string());
    }

    let bytes = expr.as_bytes();
    let mut cursor = 0usize;

    // Build the dotted path incrementally instead of collecting segments first.
    let mut path = parse_identifier_segment(expr, &mut cursor)?;
    while bytes.get(cursor) == Some(&b'.') {
        cursor += 1;
        path.push('.');
        path.push_str(&parse_identifier_segment(expr, &mut cursor)?);
    }

    let cast = parse_optional_trailing_cast(expr, &mut cursor)?;

    // Anything left over means the expression uses syntax this parser does not handle.
    if cursor == expr.len() {
        Ok(CompatibilityColumnExpr { path, cast })
    } else {
        Err("unsupported column expression in compatibility mode".to_string())
    }
}
/// Parses a compatibility-mode literal into a JSON value plus an optional cast.
///
/// After trimming `input`, the literal is either:
/// * a single-quoted string (handled by `parse_string_literal`), or
/// * a bare token that ends at the first ASCII whitespace byte or `:`; the
///   token is matched case-insensitively against `null`, `true` and `false`,
///   then tried as an `i64`, then as an `f64`.
///
/// Any `::type` casts that follow are parsed by `parse_optional_trailing_cast`.
///
/// # Errors
///
/// * `"missing value"` when `input` is empty after trimming.
/// * `"invalid numeric literal"` when the token parses as an `f64` that
///   `serde_json::Number::from_f64` rejects.
/// * `"unsupported literal in compatibility mode"` when the token matches no
///   supported form, or when input remains after the literal and cast.
/// * Any error from the string-literal or cast parsers.
pub(super) fn parse_json_value_literal(input: &str) -> Result<CompatibilityValueExpr, String> {
    let literal = input.trim();
    if literal.is_empty() {
        return Err("missing value".to_string());
    }

    let mut cursor = 0usize;
    let value = if literal.starts_with('\'') {
        serde_json::Value::String(parse_string_literal(literal, &mut cursor)?)
    } else {
        // The bare token runs up to the first whitespace byte or ':' (the start of a cast).
        cursor = literal
            .bytes()
            .position(|byte| byte.is_ascii_whitespace() || byte == b':')
            .unwrap_or(literal.len());
        let token = &literal[..cursor];

        match token {
            t if t.eq_ignore_ascii_case("null") => serde_json::Value::Null,
            t if t.eq_ignore_ascii_case("true") => serde_json::Value::Bool(true),
            t if t.eq_ignore_ascii_case("false") => serde_json::Value::Bool(false),
            _ => {
                // Integers are tried first so whole numbers stay integral in the JSON value.
                if let Ok(integer) = token.parse::<i64>() {
                    json!(integer)
                } else if let Ok(float) = token.parse::<f64>() {
                    serde_json::Number::from_f64(float)
                        .map(serde_json::Value::Number)
                        .ok_or_else(|| "invalid numeric literal".to_string())?
                } else {
                    return Err("unsupported literal in compatibility mode".to_string());
                }
            }
        }
    };

    let cast = parse_optional_trailing_cast(literal, &mut cursor)?;
    if cursor == literal.len() {
        Ok(CompatibilityValueExpr { value, cast })
    } else {
        Err("unsupported literal in compatibility mode".to_string())
    }
}
/// Consumes zero or more `::type` casts starting at byte offset `*index`.
///
/// ASCII whitespace before each `::` and after the last cast is skipped. A
/// cast target is one or more identifier segments separated by `.`; the
/// segments are joined with `.`. When several casts are chained, only the last
/// one is returned. `*index` is left on the first byte that is neither
/// whitespace nor part of a cast (or at the end of `input`).
///
/// Returns `Ok(None)` when no cast is present.
///
/// # Errors
///
/// Propagates the error from `parse_identifier_segment` when `::` is not
/// followed by a valid identifier.
pub(super) fn parse_optional_trailing_cast(
    input: &str,
    index: &mut usize,
) -> Result<Option<String>, String> {
    /// Advances `index` past any run of ASCII whitespace.
    fn skip_ascii_whitespace(bytes: &[u8], index: &mut usize) {
        while matches!(bytes.get(*index), Some(byte) if byte.is_ascii_whitespace()) {
            *index += 1;
        }
    }

    let bytes = input.as_bytes();
    let mut cast = None;

    loop {
        skip_ascii_whitespace(bytes, index);

        let at_cast_operator =
            bytes.get(*index) == Some(&b':') && bytes.get(*index + 1) == Some(&b':');
        if !at_cast_operator {
            break;
        }
        *index += 2;

        // Assemble the (possibly schema-qualified) type name directly into one string.
        let mut type_name = parse_identifier_segment(input, index)?;
        while bytes.get(*index) == Some(&b'.') {
            *index += 1;
            type_name.push('.');
            type_name.push_str(&parse_identifier_segment(input, index)?);
        }

        // A later cast in the chain overwrites an earlier one.
        cast = Some(type_name);
    }

    Ok(cast)
}
/// Parses one identifier segment from `input`, starting at byte offset `*index`.
///
/// Two forms are accepted:
/// * a double-quoted identifier, where `""` inside the quotes stands for a
///   literal `"`; the surrounding quotes are not part of the returned text;
/// * a bare identifier: a non-empty run of ASCII alphanumerics and `_`.
///
/// On success `*index` points just past the segment (past the closing quote
/// for the quoted form).
///
/// # Errors
///
/// * `"missing identifier"` when `*index` is at or past the end of `input`, or
///   when the bare form matches zero bytes.
/// * `"unterminated quoted identifier"` when the closing `"` is never found.
pub(super) fn parse_identifier_segment(input: &str, index: &mut usize) -> Result<String, String> {
    let bytes = input.as_bytes();
    if *index >= bytes.len() {
        return Err("missing identifier".to_string());
    }

    if bytes[*index] == b'"' {
        *index += 1;
        let mut segment = String::new();
        // Start of the current run of ordinary (non-quote) bytes.
        let mut run_start = *index;
        while *index < bytes.len() {
            if bytes[*index] != b'"' {
                *index += 1;
                continue;
            }
            // Copy the run as a `str` slice rather than byte-by-byte: casting each
            // byte to `char` would split multi-byte UTF-8 characters into garbage.
            // Both ends sit next to an ASCII `"`, so they are valid char boundaries.
            segment.push_str(&input[run_start..*index]);
            if bytes.get(*index + 1) == Some(&b'"') {
                // Doubled quote: an escaped literal `"`.
                segment.push('"');
                *index += 2;
                run_start = *index;
                continue;
            }
            *index += 1;
            return Ok(segment);
        }
        return Err("unterminated quoted identifier".to_string());
    }

    let start = *index;
    while *index < bytes.len()
        && (bytes[*index].is_ascii_alphanumeric() || bytes[*index] == b'_')
    {
        *index += 1;
    }
    if *index == start {
        return Err("missing identifier".to_string());
    }
    Ok(input[start..*index].to_string())
}
/// Parses a single-quoted string literal from `input`, starting at byte
/// offset `*index`.
///
/// Inside the literal, `''` stands for one literal `'`. The surrounding quotes
/// are not part of the returned text. On success `*index` points just past the
/// closing quote.
///
/// # Errors
///
/// * `"expected string literal"` when the byte at `*index` is not `'` (or
///   `*index` is past the end of `input`).
/// * `"unterminated string literal"` when the closing `'` is never found.
pub(super) fn parse_string_literal(input: &str, index: &mut usize) -> Result<String, String> {
    let bytes = input.as_bytes();
    if bytes.get(*index) != Some(&b'\'') {
        return Err("expected string literal".to_string());
    }
    *index += 1;

    let mut value = String::new();
    // Start of the current run of ordinary (non-quote) bytes.
    let mut run_start = *index;
    while *index < bytes.len() {
        if bytes[*index] != b'\'' {
            *index += 1;
            continue;
        }
        // Copy the run as a `str` slice rather than byte-by-byte: casting each
        // byte to `char` would split multi-byte UTF-8 characters into garbage.
        // Both ends sit next to an ASCII `'`, so they are valid char boundaries.
        value.push_str(&input[run_start..*index]);
        if bytes.get(*index + 1) == Some(&b'\'') {
            // Doubled quote: an escaped literal `'`.
            value.push('\'');
            *index += 2;
            run_start = *index;
            continue;
        }
        *index += 1;
        return Ok(value);
    }
    Err("unterminated string literal".to_string())
}
/// Validates a foreign-key hint.
///
/// A leading `parent.` prefix is removed if present; otherwise a leading
/// `child.` prefix is removed if present. The remainder (or the whole of `raw`
/// when neither prefix matches) is checked with `validate_identifier` under
/// the label `foreign_key`.
///
/// # Errors
///
/// Returns the error from `validate_identifier` when the checked text is not
/// accepted.
pub(super) fn validate_foreign_key_hint(raw: &str) -> Result<(), String> {
    // `parent.` takes precedence over `child.`; only one prefix is ever stripped.
    let column = raw
        .strip_prefix("parent.")
        .or_else(|| raw.strip_prefix("child."))
        .unwrap_or(raw);
    validate_identifier(column, "foreign_key")
}
/// Checks that `identifier` is accepted by `sanitize_identifier`.
///
/// # Errors
///
/// Returns `"{label} '{identifier}' must be a valid SQL identifier"` when
/// `sanitize_identifier` returns `None`.
pub(super) fn validate_identifier(identifier: &str, label: &str) -> Result<(), String> {
    match sanitize_identifier(identifier) {
        Some(_) => Ok(()),
        None => Err(format!(
            "{label} '{identifier}' must be a valid SQL identifier"
        )),
    }
}
/// Returns `identifier`, trimmed and wrapped in double quotes, when it is a
/// non-empty run of ASCII alphanumerics and `_`.
///
/// Returns `None` when the trimmed text is empty or contains any other byte.
pub(super) fn sanitize_identifier(identifier: &str) -> Option<String> {
    let name = identifier.trim();
    let is_word_byte = |byte: u8| byte == b'_' || byte.is_ascii_alphanumeric();

    if name.is_empty() || !name.bytes().all(is_word_byte) {
        return None;
    }
    Some(format!("\"{name}\""))
}
/// Renders a table reference as `schema.table` when a schema name is set,
/// or as just the table name otherwise.
///
/// The names are inserted as stored; no quoting or escaping is applied here.
pub(super) fn render_table_ref(table: &GatewayRelationSelectTableRef) -> String {
    if let Some(schema_name) = table.schema_name.as_deref() {
        return format!("{schema_name}.{}", table.table_name);
    }
    table.table_name.clone()
}