use crate::numeric::StrToF64;
use crate::schema::ColDef;
use crate::translate::emitter::TransactionMode;
use crate::translate::expr::{walk_expr, walk_expr_mut, WalkControl};
use crate::translate::plan::{BitSet, JoinedTable, TableReferences};
use crate::translate::planner::{parse_row_id, TableMask};
use crate::types::IOResult;
use crate::IO;
use crate::{
schema::{Column, Schema, Table, Type},
types::{Value, ValueType},
LimboError, OpenFlags, Result, Statement, SymbolTable,
};
use either::Either;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::future::Future;
use tracing::{instrument, Level};
use turso_macros::match_ignore_ascii_case;
use turso_parser::ast::{self, CreateTableBody, Expr, Literal, UnaryOperator};
use turso_parser::parser::Parser;
#[macro_export]
macro_rules! io_yield_one {
($c:expr) => {
return Ok(IOResult::IO(IOCompletions::Single($c)));
};
}
#[macro_export]
macro_rules! eq_ignore_ascii_case {
( $var:expr, $value:literal ) => {{
match_ignore_ascii_case!(match $var {
$value => true,
_ => false,
})
}};
}
#[macro_export]
macro_rules! contains_ignore_ascii_case {
( $var:expr, $value:literal ) => {{
let compare_to_idx = $var.len().saturating_sub($value.len());
if $var.len() < $value.len() {
false
} else {
let mut result = false;
for i in 0..=compare_to_idx {
if eq_ignore_ascii_case!(&$var[i..i + $value.len()], $value) {
result = true;
break;
}
}
result
}
}};
}
#[macro_export]
macro_rules! starts_with_ignore_ascii_case {
( $var:expr, $value:literal ) => {{
if $var.len() < $value.len() {
false
} else {
eq_ignore_ascii_case!(&$var[..$value.len()], $value)
}
}};
}
#[macro_export]
macro_rules! ends_with_ignore_ascii_case {
( $var:expr, $value:literal ) => {{
if $var.len() < $value.len() {
false
} else {
eq_ignore_ascii_case!(&$var[$var.len() - $value.len()..], $value)
}
}};
}
pub trait IOExt {
fn block<T>(&self, f: impl FnMut() -> Result<IOResult<T>>) -> Result<T>;
fn wait<T, F>(&self, f: F) -> impl Future<Output = Result<T>> + Send
where
F: FnMut() -> Result<IOResult<T>> + Send,
T: Send;
}
impl<I: ?Sized + IO> IOExt for I {
fn block<T>(&self, mut f: impl FnMut() -> Result<IOResult<T>>) -> Result<T> {
Ok(loop {
match f()? {
IOResult::Done(v) => break v,
IOResult::IO(io) => io.wait(self)?,
}
})
}
async fn wait<T, F>(&self, mut f: F) -> Result<T>
where
F: FnMut() -> Result<IOResult<T>> + Send,
T: Send,
{
Ok(loop {
match f()? {
IOResult::Done(v) => break v,
IOResult::IO(io) => io.wait_async(self).await?,
}
})
}
}
const QUOTE_PAIRS: &[(char, char)] = &[
('"', '"'),
('[', ']'),
('`', '`'),
('\'', '\''), ];
pub fn normalize_ident(identifier: &str) -> String {
identifier.to_lowercase()
}
pub fn escape_sql_string_literal(literal: &str) -> String {
literal.replace('\'', "''")
}
pub fn quote_identifier(name: &str) -> String {
let needs_quoting = name.is_empty()
|| name.as_bytes()[0].is_ascii_digit()
|| !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_')
|| turso_parser::lexer::is_quotable_keyword(name.as_bytes());
if needs_quoting {
let escaped = name.replace('"', "\"\"");
format!("\"{escaped}\"")
} else {
name.to_string()
}
}
pub const PRIMARY_KEY_AUTOMATIC_INDEX_NAME_PREFIX: &str = "sqlite_autoindex_";
pub struct UnparsedFromSqlIndex {
pub table_name: String,
pub root_page: i64,
pub sql: String,
}
#[instrument(skip_all, level = Level::INFO)]
pub fn parse_schema_rows(
mut rows: Statement,
schema: &mut Schema,
syms: &SymbolTable,
mv_tx: Option<(u64, TransactionMode)>,
resolve_attached_db: &dyn Fn(&str) -> Option<usize>,
) -> Result<()> {
rows.set_mv_tx(mv_tx);
let mv_store = rows.mv_store().clone();
let mut from_sql_indexes = Vec::with_capacity(10);
let mut automatic_indices = HashMap::with_capacity_and_hasher(10, Default::default());
let mut dbsp_state_roots: HashMap<String, i64> = HashMap::default();
let mut dbsp_state_index_roots: HashMap<String, i64> = HashMap::default();
let mut materialized_view_info: HashMap<String, (String, i64)> = HashMap::default();
rows.run_with_row_callback(|row| {
let ty = row.get::<&str>(0)?;
let name = row.get::<&str>(1)?;
let table_name = row.get::<&str>(2)?;
let root_page = row.get::<i64>(3)?;
let sql = row.get::<&str>(4).ok();
schema.handle_schema_row(
ty,
name,
table_name,
root_page,
sql,
syms,
&mut from_sql_indexes,
&mut automatic_indices,
&mut dbsp_state_roots,
&mut dbsp_state_index_roots,
&mut materialized_view_info,
resolve_attached_db,
)
})?;
schema.populate_indices(
syms,
from_sql_indexes,
automatic_indices,
mv_store.is_some(),
)?;
schema.populate_materialized_views(
materialized_view_info,
dbsp_state_roots,
dbsp_state_index_roots,
)?;
Ok(())
}
fn cmp_numeric_strings(num_str: &str, other: &str) -> bool {
fn parse(s: &str) -> Option<Either<i64, f64>> {
if let Ok(i) = s.parse::<i64>() {
Some(Either::Left(i))
} else if let Ok(f) = s.parse::<f64>() {
Some(Either::Right(f))
} else {
None
}
}
match (parse(num_str), parse(other)) {
(Some(Either::Left(i1)), Some(Either::Left(i2))) => i1 == i2,
(Some(Either::Right(f1)), Some(Either::Right(f2))) => f1 == f2,
(Some(Either::Left(_)), Some(Either::Right(_)))
| (Some(Either::Right(_)), Some(Either::Left(_))) => false,
_ => num_str == other,
}
}
pub fn check_ident_equivalency(ident1: &str, ident2: &str) -> bool {
fn strip_quotes(identifier: &str) -> &str {
for &(start, end) in QUOTE_PAIRS {
if identifier.starts_with(start) && identifier.ends_with(end) {
return &identifier[1..identifier.len() - 1];
}
}
identifier
}
strip_quotes(ident1).eq_ignore_ascii_case(strip_quotes(ident2))
}
pub fn module_name_from_sql(sql: &str) -> Result<&str> {
if let Some(start) = sql.find("USING") {
let start = start + 6;
let end = sql[start..]
.find(|c: char| c.is_whitespace() || c == ';' || c == '(')
.unwrap_or(sql.len() - start)
+ start;
Ok(sql[start..end].trim())
} else {
Err(LimboError::InvalidArgument(
"Expected 'USING' in module name".to_string(),
))
}
}
pub fn module_args_from_sql(sql: &str) -> Result<Vec<turso_ext::Value>> {
if !sql.contains('(') {
return Ok(vec![]);
}
let start = sql.find('(').ok_or_else(|| {
LimboError::InvalidArgument("Expected '(' in module argument list".to_string())
})? + 1;
let end = sql.rfind(')').ok_or_else(|| {
LimboError::InvalidArgument("Expected ')' in module argument list".to_string())
})?;
let mut args = Vec::new();
let mut current_arg = String::new();
let mut chars = sql[start..end].chars().peekable();
let mut in_quotes = false;
while let Some(c) = chars.next() {
match c {
'\'' => {
if in_quotes {
if chars.peek() == Some(&'\'') {
current_arg.push('\'');
chars.next();
} else {
in_quotes = false;
args.push(turso_ext::Value::from_text(current_arg.trim().to_string()));
current_arg.clear();
while let Some(&nc) = chars.peek() {
if nc == ',' {
chars.next(); break;
} else if nc.is_whitespace() {
chars.next();
} else {
return Err(LimboError::InvalidArgument(
"Unexpected characters after quoted argument".to_string(),
));
}
}
}
} else {
in_quotes = true;
}
}
',' => {
if !in_quotes {
if !current_arg.trim().is_empty() {
args.push(turso_ext::Value::from_text(current_arg.trim().to_string()));
current_arg.clear();
}
} else {
current_arg.push(c);
}
}
_ => {
current_arg.push(c);
}
}
}
if !current_arg.trim().is_empty() && !in_quotes {
args.push(turso_ext::Value::from_text(current_arg.trim().to_string()));
}
if in_quotes {
return Err(LimboError::InvalidArgument(
"Unterminated string literal in module arguments".to_string(),
));
}
Ok(args)
}
pub fn check_literal_equivalency(lhs: &Literal, rhs: &Literal) -> bool {
match (lhs, rhs) {
(Literal::Numeric(n1), Literal::Numeric(n2)) => cmp_numeric_strings(n1, n2),
(Literal::String(s1), Literal::String(s2)) => s1 == s2,
(Literal::Blob(b1), Literal::Blob(b2)) => b1 == b2,
(Literal::Keyword(k1), Literal::Keyword(k2)) => check_ident_equivalency(k1, k2),
(Literal::Null, Literal::Null) => true,
(Literal::CurrentDate, Literal::CurrentDate) => true,
(Literal::CurrentTime, Literal::CurrentTime) => true,
(Literal::CurrentTimestamp, Literal::CurrentTimestamp) => true,
_ => false,
}
}
pub(crate) fn expr_tables_subset_of(
expr: &Expr,
table_references: &TableReferences,
allowed: &TableMask,
) -> bool {
let mut ok = true;
let _ = walk_expr(expr, &mut |e: &Expr| -> Result<WalkControl> {
match e {
Expr::Column { table, .. } | Expr::RowId { table, .. } => {
if let Some(idx) = table_references
.joined_tables()
.iter()
.position(|t| t.internal_id == *table)
{
if !allowed.get(idx) {
ok = false;
return Ok(WalkControl::SkipChildren);
}
}
}
_ => {}
}
Ok(WalkControl::Continue)
});
ok
}
pub fn simple_bind_expr(
joined_table: &JoinedTable,
result_columns: &[ast::ResultColumn],
expr: &mut ast::Expr,
) -> Result<()> {
let internal_id = joined_table.internal_id;
walk_expr_mut(expr, &mut |expr: &mut ast::Expr| -> Result<WalkControl> {
#[allow(clippy::single_match)]
match expr {
Expr::Id(id) => {
for result_column in result_columns.iter() {
if let ast::ResultColumn::Expr(result, Some(ast::As::As(alias))) = result_column
{
if alias.as_str().eq_ignore_ascii_case(id.as_str()) {
*expr = *result.clone();
return Ok(WalkControl::Continue);
}
}
}
let col_idx = joined_table.columns().iter().position(|c| {
c.name
.as_ref()
.is_some_and(|name| name.eq_ignore_ascii_case(id.as_str()))
});
if let Some(col_idx) = col_idx {
let col = joined_table.table.columns().get(col_idx).unwrap();
*expr = ast::Expr::Column {
database: None,
table: internal_id,
column: col_idx,
is_rowid_alias: col.is_rowid_alias(),
};
} else {
let is_btree_table = matches!(joined_table.table, Table::BTree(_));
if is_btree_table {
if let Some(rowid) =
parse_row_id(&normalize_ident(id.as_str()), internal_id, || false)?
{
*expr = rowid;
}
}
}
}
_ => {}
}
Ok(WalkControl::Continue)
})?;
Ok(())
}
pub fn try_substitute_parameters(
pattern: &Expr,
parameters: &HashMap<i32, Expr>,
) -> Option<Box<Expr>> {
match pattern {
Expr::FunctionCall {
name,
distinctness,
args,
order_by,
filter_over,
} => {
let mut substituted = Vec::new();
for arg in args {
substituted.push(try_substitute_parameters(arg, parameters)?);
}
Some(Box::new(Expr::FunctionCall {
args: substituted,
distinctness: *distinctness,
name: name.clone(),
order_by: order_by.clone(),
filter_over: filter_over.clone(),
}))
}
Expr::Variable(var) => {
if var.name.is_some() {
return None;
}
let Ok(var) = i32::try_from(var.index.get()) else {
return None;
};
Some(Box::new(parameters.get(&var)?.clone()))
}
_ => Some(Box::new(pattern.clone())),
}
}
pub fn try_capture_parameters(pattern: &Expr, query: &Expr) -> Option<HashMap<i32, Expr>> {
let mut captured = HashMap::default();
match (pattern, query) {
(
Expr::FunctionCall {
name: name1,
distinctness: distinct1,
args: args1,
order_by: order1,
filter_over: filter1,
},
Expr::FunctionCall {
name: name2,
distinctness: distinct2,
args: args2,
order_by: order2,
filter_over: filter2,
},
) => {
if !name1.as_str().eq_ignore_ascii_case(name2.as_str()) {
return None;
}
if distinct1.is_some() || distinct2.is_some() {
return None;
}
if !order1.is_empty() || !order2.is_empty() {
return None;
}
if filter1.filter_clause.is_some() || filter1.over_clause.is_some() {
return None;
}
if filter2.filter_clause.is_some() || filter2.over_clause.is_some() {
return None;
}
for (arg1, arg2) in args1.iter().zip(args2.iter()) {
let result = try_capture_parameters(arg1, arg2)?;
captured.extend(result);
}
Some(captured)
}
(Expr::Variable(var), expr) => {
if var.name.is_some() {
return None;
}
let Ok(var) = i32::try_from(var.index.get()) else {
return None;
};
captured.insert(var, expr.clone());
Some(captured)
}
(
Expr::Id(_) | Expr::Name(_) | Expr::Column { .. },
Expr::Id(_) | Expr::Name(_) | Expr::Column { .. },
) => {
if pattern == query {
Some(captured)
} else {
None
}
}
(_, _) => None,
}
}
pub fn count_fts_column_args(expr: &Expr) -> usize {
match expr {
Expr::FunctionCall { name, args, .. } => {
let name_lower = name.as_str().to_lowercase();
match name_lower.as_str() {
"fts_match" | "fts_score" => args.len().saturating_sub(1),
"fts_highlight" => args.len().saturating_sub(3),
_ => 0,
}
}
_ => 0,
}
}
pub fn try_capture_parameters_column_agnostic(
pattern: &Expr, query: &Expr, num_column_args: usize, ) -> Option<HashMap<i32, Expr>> {
if num_column_args == 0 {
return try_capture_parameters(pattern, query);
}
let (
Expr::FunctionCall {
name: pattern_name,
distinctness: pattern_distinct,
args: pattern_args,
order_by: pattern_order,
filter_over: pattern_filter,
},
Expr::FunctionCall {
name: query_name,
distinctness: query_distinct,
args: query_args,
order_by: query_order,
filter_over: query_filter,
},
) = (pattern, query)
else {
return try_capture_parameters(pattern, query);
};
if !pattern_name
.as_str()
.eq_ignore_ascii_case(query_name.as_str())
{
return None;
}
if pattern_args.len() != query_args.len() {
return None;
}
if pattern_distinct.is_some() || query_distinct.is_some() {
return None;
}
if !pattern_order.is_empty() || !query_order.is_empty() {
return None;
}
if pattern_filter.filter_clause.is_some() || pattern_filter.over_clause.is_some() {
return None;
}
if query_filter.filter_clause.is_some() || query_filter.over_clause.is_some() {
return None;
}
let mut captured = HashMap::default();
let pattern_col_args = &pattern_args[..num_column_args];
let query_col_args = &query_args[..num_column_args];
let pattern_rest = &pattern_args[num_column_args..];
let query_rest = &query_args[num_column_args..];
let mut matched_pattern_indices = BitSet::default();
for query_col in query_col_args {
let mut found_match = false;
for (i, pattern_col) in pattern_col_args.iter().enumerate() {
if matched_pattern_indices.get(i) {
continue;
}
if exprs_are_equivalent(pattern_col, query_col) {
matched_pattern_indices.set(i);
found_match = true;
break;
}
}
if !found_match {
return None;
}
}
if matched_pattern_indices.count() != pattern_col_args.len() {
return None;
}
for (pattern_arg, query_arg) in pattern_rest.iter().zip(query_rest.iter()) {
let result = try_capture_parameters(pattern_arg, query_arg)?;
captured.extend(result);
}
Some(captured)
}
pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
match (expr1, expr2) {
(
Expr::Between {
lhs: lhs1,
not: not1,
start: start1,
end: end1,
},
Expr::Between {
lhs: lhs2,
not: not2,
start: start2,
end: end2,
},
) => {
not1 == not2
&& exprs_are_equivalent(lhs1, lhs2)
&& exprs_are_equivalent(start1, start2)
&& exprs_are_equivalent(end1, end2)
}
(Expr::Binary(lhs1, op1, rhs1), Expr::Binary(lhs2, op2, rhs2)) => {
op1 == op2
&& ((exprs_are_equivalent(lhs1, lhs2) && exprs_are_equivalent(rhs1, rhs2))
|| (op1.is_commutative()
&& exprs_are_equivalent(lhs1, rhs2)
&& exprs_are_equivalent(rhs1, lhs2)))
}
(
Expr::Case {
base: base1,
when_then_pairs: pairs1,
else_expr: else1,
},
Expr::Case {
base: base2,
when_then_pairs: pairs2,
else_expr: else2,
},
) => {
base1 == base2
&& pairs1.len() == pairs2.len()
&& pairs1.iter().zip(pairs2).all(|((w1, t1), (w2, t2))| {
exprs_are_equivalent(w1, w2) && exprs_are_equivalent(t1, t2)
})
&& else1 == else2
}
(
Expr::Cast {
expr: expr1,
type_name: type1,
},
Expr::Cast {
expr: expr2,
type_name: type2,
},
) => {
exprs_are_equivalent(expr1, expr2)
&& match (type1, type2) {
(Some(t1), Some(t2)) => t1.name.eq_ignore_ascii_case(&t2.name),
_ => false,
}
}
(Expr::Collate(expr1, collation1), Expr::Collate(expr2, collation2)) => {
exprs_are_equivalent(expr1, expr2)
&& collation1
.as_str()
.eq_ignore_ascii_case(collation2.as_str())
}
(
Expr::FunctionCall {
name: name1,
distinctness: distinct1,
args: args1,
order_by: order1,
filter_over: filter1,
},
Expr::FunctionCall {
name: name2,
distinctness: distinct2,
args: args2,
order_by: order2,
filter_over: filter2,
},
) => {
name1.as_str().eq_ignore_ascii_case(name2.as_str())
&& distinct1 == distinct2
&& args1 == args2
&& order1 == order2
&& filter1 == filter2
}
(
Expr::FunctionCallStar {
name: name1,
filter_over: filter1,
},
Expr::FunctionCallStar {
name: name2,
filter_over: filter2,
},
) => {
name1.as_str().eq_ignore_ascii_case(name2.as_str())
&& match (&filter1.filter_clause, &filter2.filter_clause) {
(Some(expr1), Some(expr2)) => exprs_are_equivalent(expr1, expr2),
(None, None) => true,
_ => false,
}
&& filter1.over_clause == filter2.over_clause
}
(Expr::NotNull(expr1), Expr::NotNull(expr2)) => exprs_are_equivalent(expr1, expr2),
(Expr::IsNull(expr1), Expr::IsNull(expr2)) => exprs_are_equivalent(expr1, expr2),
(Expr::Literal(lit1), Expr::Literal(lit2)) => check_literal_equivalency(lit1, lit2),
(Expr::Id(id1), Expr::Id(id2)) => check_ident_equivalency(id1.as_str(), id2.as_str()),
(Expr::Unary(op1, expr1), Expr::Unary(op2, expr2)) => {
op1 == op2 && exprs_are_equivalent(expr1, expr2)
}
(Expr::Variable(val), Expr::Variable(val2)) => val == val2,
(Expr::Parenthesized(exprs1), Expr::Parenthesized(exprs2)) => {
exprs1.len() == exprs2.len()
&& exprs1
.iter()
.zip(exprs2)
.all(|(e1, e2)| exprs_are_equivalent(e1, e2))
}
(Expr::Parenthesized(exprs1), exprs2) | (exprs2, Expr::Parenthesized(exprs1)) => {
exprs1.len() == 1 && exprs_are_equivalent(&exprs1[0], exprs2)
}
(Expr::Qualified(tn1, cn1), Expr::Qualified(tn2, cn2)) => {
check_ident_equivalency(tn1.as_str(), tn2.as_str())
&& check_ident_equivalency(cn1.as_str(), cn2.as_str())
}
(Expr::DoublyQualified(sn1, tn1, cn1), Expr::DoublyQualified(sn2, tn2, cn2)) => {
check_ident_equivalency(sn1.as_str(), sn2.as_str())
&& check_ident_equivalency(tn1.as_str(), tn2.as_str())
&& check_ident_equivalency(cn1.as_str(), cn2.as_str())
}
(
Expr::InList {
lhs: lhs1,
not: not1,
rhs: rhs1,
},
Expr::InList {
lhs: lhs2,
not: not2,
rhs: rhs2,
},
) => {
*not1 == *not2
&& exprs_are_equivalent(lhs1, lhs2)
&& rhs1.len() == rhs2.len()
&& rhs1
.iter()
.zip(rhs2.iter())
.all(|(a, b)| exprs_are_equivalent(a, b))
}
(
Expr::Column {
database: db1,
is_rowid_alias: r1,
table: tbl_1,
column: col_1,
},
Expr::Column {
database: db2,
is_rowid_alias: r2,
table: tbl_2,
column: col_2,
},
) => tbl_1 == tbl_2 && col_1 == col_2 && db1 == db2 && r1 == r2,
_ => expr1 == expr2,
}
}
pub(crate) fn expr_contains_null(expr: &ast::Expr) -> bool {
let mut contains_null = false;
let _ = walk_expr(expr, &mut |expr: &ast::Expr| -> Result<WalkControl> {
if let ast::Expr::Literal(ast::Literal::Null) = expr {
contains_null = true;
return Ok(WalkControl::SkipChildren);
}
Ok(WalkControl::Continue)
}); contains_null
}
pub(crate) fn type_from_name(type_name: &str) -> (Type, bool) {
let type_name = type_name.as_bytes();
if type_name.is_empty() {
return (Type::Blob, false);
}
if eq_ignore_ascii_case!(type_name, b"INTEGER") {
return (Type::Integer, true);
}
if contains_ignore_ascii_case!(type_name, b"INT") {
return (Type::Integer, false);
}
if let Some(ty) = type_name.windows(4).find_map(|s| {
match_ignore_ascii_case!(match s {
b"CHAR" | b"CLOB" | b"TEXT" => Some(Type::Text),
b"BLOB" => Some(Type::Blob),
b"REAL" | b"FLOA" | b"DOUB" => Some(Type::Real),
_ => None,
})
}) {
return (ty, false);
}
(Type::Numeric, false)
}
pub fn columns_from_create_table_body(
body: &turso_parser::ast::CreateTableBody,
) -> crate::Result<Vec<Column>> {
let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else {
return Err(crate::LimboError::ParseError(
"CREATE TABLE body must contain columns and constraints".to_string(),
));
};
columns
.iter()
.map(Column::try_from)
.collect::<crate::Result<Vec<Column>>>()
}
#[derive(Debug, Default, PartialEq)]
pub struct OpenOptions<'a> {
pub authority: Option<&'a str>,
pub path: String,
pub vfs: Option<String>,
pub mode: OpenMode,
pub modeof: Option<String>,
pub cache: CacheMode,
pub immutable: bool,
pub cipher: Option<String>,
pub hexkey: Option<String>,
}
pub const MEMORY_PATH: &str = ":memory:";
#[derive(Clone, Default, Debug, Copy, PartialEq)]
pub enum OpenMode {
ReadOnly,
ReadWrite,
Memory,
#[default]
ReadWriteCreate,
}
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
#[default]
Private,
Shared,
}
impl From<&str> for CacheMode {
fn from(s: &str) -> Self {
match s {
"private" => CacheMode::Private,
"shared" => CacheMode::Shared,
_ => CacheMode::Private,
}
}
}
impl OpenMode {
pub fn from_str(s: &str) -> Result<Self> {
let s_bytes = s.trim().as_bytes();
match_ignore_ascii_case!(match s_bytes {
b"ro" => Ok(OpenMode::ReadOnly),
b"rw" => Ok(OpenMode::ReadWrite),
b"memory" => Ok(OpenMode::Memory),
b"rwc" => Ok(OpenMode::ReadWriteCreate),
_ => Err(LimboError::InvalidArgument(format!(
"Invalid mode: '{s}'. Expected one of 'ro', 'rw', 'memory', 'rwc'"
))),
})
}
}
fn is_windows_path(path: &str) -> bool {
path.len() >= 3
&& path.chars().nth(1) == Some(':')
&& (path.chars().nth(2) == Some('/') || path.chars().nth(2) == Some('\\'))
}
fn normalize_windows_path(path: &str) -> String {
let mut normalized = path.replace("\\", "/");
while normalized.contains("//") {
normalized = normalized.replace("//", "/");
}
if normalized.len() >= 3
&& !normalized.starts_with('/')
&& normalized.chars().nth(1) == Some(':')
&& normalized.chars().nth(2) == Some('/')
{
normalized.insert(0, '/');
}
normalized
}
impl<'a> OpenOptions<'a> {
pub fn parse(uri: &'a str) -> Result<OpenOptions<'a>> {
if !uri.starts_with("file:") {
return Ok(OpenOptions {
path: uri.to_string(),
..Default::default()
});
}
let mut opts = OpenOptions::default();
let without_scheme = &uri[5..];
let (without_fragment, _) = without_scheme
.split_once('#')
.unwrap_or((without_scheme, ""));
let (without_query, query) = without_fragment
.split_once('?')
.unwrap_or((without_fragment, ""));
parse_query_params(query, &mut opts)?;
if let Some(after_slashes) = without_query.strip_prefix("//") {
let (authority, path) = after_slashes.split_once('/').unwrap_or((after_slashes, ""));
if !(authority.is_empty() || authority == "localhost") {
return Err(LimboError::InvalidArgument(format!(
"Invalid authority '{authority}'. Only '' or 'localhost' allowed."
)));
}
opts.authority = if authority.is_empty() {
None
} else {
Some(authority)
};
if is_windows_path(path) {
opts.path = normalize_windows_path(&decode_percent(path));
} else if !path.is_empty() {
opts.path = format!("/{}", decode_percent(path));
} else {
opts.path = String::new();
}
} else {
opts.path = decode_percent(without_query);
}
Ok(opts)
}
pub fn get_flags(&self) -> Result<OpenFlags> {
if self.mode != OpenMode::ReadWriteCreate && self.modeof.is_some() {
return Err(LimboError::InvalidArgument(
"modeof is not applicable without mode=rwc".to_string(),
));
}
Ok(match self.mode {
OpenMode::ReadWriteCreate => OpenFlags::Create,
OpenMode::ReadOnly => OpenFlags::ReadOnly,
_ => OpenFlags::default(),
})
}
}
fn parse_query_params(query: &str, opts: &mut OpenOptions) -> Result<()> {
for param in query.split('&') {
if let Some((key, value)) = param.split_once('=') {
let decoded_value = decode_percent(value);
match key {
"mode" => opts.mode = OpenMode::from_str(value)?,
"modeof" => opts.modeof = Some(decoded_value),
"cache" => opts.cache = decoded_value.as_str().into(),
"immutable" => opts.immutable = decoded_value == "1",
"vfs" => opts.vfs = Some(decoded_value),
"cipher" => opts.cipher = Some(decoded_value),
"hexkey" => opts.hexkey = Some(decoded_value),
_ => {}
}
}
}
Ok(())
}
pub fn decode_percent(uri: &str) -> String {
let from_hex_digit = |digit: u8| -> Option<u8> {
match digit {
b'0'..=b'9' => Some(digit - b'0'),
b'A'..=b'F' => Some(digit - b'A' + 10),
b'a'..=b'f' => Some(digit - b'a' + 10),
_ => None,
}
};
let offset = uri.chars().take_while(|&c| c != '%').count();
if offset >= uri.len() {
return uri.to_string();
}
let mut decoded: Vec<u8> = Vec::with_capacity(uri.len());
let (ascii, mut data) = uri.as_bytes().split_at(offset);
decoded.extend_from_slice(ascii);
loop {
let mut parts = data.splitn(2, |&c| c == b'%');
let non_escaped_part = parts.next().unwrap();
let rest = parts.next();
if rest.is_none() && decoded.is_empty() {
return String::from_utf8_lossy(data).to_string();
}
decoded.extend_from_slice(non_escaped_part);
match rest {
Some(rest) => match rest.get(0..2) {
Some([first, second]) => match from_hex_digit(*first) {
Some(first_val) => match from_hex_digit(*second) {
Some(second_val) => {
decoded.push((first_val << 4) | second_val);
data = &rest[2..];
}
None => {
decoded.extend_from_slice(&[b'%', *first]);
data = &rest[1..];
}
},
None => {
decoded.push(b'%');
data = rest;
}
},
_ => {
decoded.push(b'%');
decoded.extend_from_slice(rest);
break;
}
},
None => break,
}
}
String::from_utf8_lossy(&decoded).to_string()
}
pub fn trim_ascii_whitespace(s: &str) -> &str {
let bytes = s.as_bytes();
let start = bytes
.iter()
.position(|&b| !b.is_ascii_whitespace())
.unwrap_or(bytes.len());
let end = bytes
.iter()
.rposition(|&b| !b.is_ascii_whitespace())
.map(|i| i + 1)
.unwrap_or(0);
if start <= end {
&s[start..end]
} else {
""
}
}
pub fn checked_cast_text_to_numeric(text: &str, lossless: bool) -> std::result::Result<Value, ()> {
let original_len = text.trim().len();
let (kind, text) = parse_numeric_str(text)?;
if original_len != text.len() && lossless {
return Err(());
}
match kind {
ValueType::Integer => match text.parse::<i64>() {
Ok(i) => Ok(Value::from_i64(i)),
Err(e) => {
if matches!(
e.kind(),
std::num::IntErrorKind::PosOverflow | std::num::IntErrorKind::NegOverflow
) {
let value = text.parse::<f64>().unwrap_or_default();
let factor = 10f64.powi(15 - value.abs().log10().ceil() as i32);
Ok(Value::from_f64((value * factor).round() / factor))
} else {
Err(())
}
}
},
ValueType::Float => {
let value = text.parse::<f64>().unwrap_or(0.0);
Ok(real_to_numeric_value(value))
}
_ => unreachable!(),
}
}
fn real_to_numeric_value(value: f64) -> Value {
const INT51_MAX: i64 = 1 << 51;
if value == 0.0 {
return Value::from_i64(0);
}
if value.is_finite() {
let i = value as i64;
if (i as f64) == value && (-INT51_MAX..INT51_MAX).contains(&i) {
return Value::from_i64(i);
}
}
Value::from_f64(value)
}
fn parse_numeric_str(text: &str) -> Result<(ValueType, &str), ()> {
let text = text.trim();
let bytes = text.as_bytes();
if matches!(
bytes,
[] | [b'e', ..] | [b'E', ..] | [b'.', b'e' | b'E', ..]
) {
return Err(());
}
let mut end = 0;
let mut has_decimal = false;
let mut has_exponent = false;
if bytes[0] == b'-' || bytes[0] == b'+' {
end = 1;
}
while end < bytes.len() {
match bytes[end] {
b'0'..=b'9' => end += 1,
b'.' if !has_decimal && !has_exponent => {
has_decimal = true;
end += 1;
}
b'e' | b'E' if !has_exponent => {
has_exponent = true;
end += 1;
if end < bytes.len() && (bytes[end] == b'+' || bytes[end] == b'-') {
end += 1;
}
}
_ => break,
}
}
if end == 0 || (end == 1 && (bytes[0] == b'-' || bytes[0] == b'+')) {
return Err(());
}
let last = bytes[end - 1];
if last.eq_ignore_ascii_case(&b'e') {
return Ok((ValueType::Float, &text[0..end - 1]));
} else if has_exponent && (last == b'-' || last == b'+') {
return Ok((ValueType::Float, &text[0..end - 2]));
}
Ok((
if !has_decimal && !has_exponent {
ValueType::Integer
} else {
ValueType::Float
},
&text[..end],
))
}
pub fn cast_real_to_integer(float: f64) -> std::result::Result<i64, ()> {
if !float.is_finite() || float.trunc() != float {
return Err(());
}
let int_val = float as i64;
if int_val == i64::MIN || int_val == i64::MAX {
return Err(());
}
if (int_val as f64) != float {
return Err(());
}
Ok(int_val)
}
pub fn parse_numeric_literal(text: &str) -> Result<Value> {
let text = if text.contains('_') {
std::borrow::Cow::Owned(text.replace('_', ""))
} else {
std::borrow::Cow::Borrowed(text)
};
if text.starts_with("0x") || text.starts_with("0X") {
let value = u64::from_str_radix(&text[2..], 16)? as i64;
return Ok(Value::from_i64(value));
} else if text.starts_with("-0x") || text.starts_with("-0X") {
let value = u64::from_str_radix(&text[3..], 16)? as i64;
if value == i64::MIN {
return Err(LimboError::IntegerOverflow);
}
return Ok(Value::from_i64(-value));
}
if let Ok(int_value) = text.parse::<i64>() {
return Ok(Value::from_i64(int_value));
}
let Some(StrToF64::Fractional(float) | StrToF64::Decimal(float)) =
crate::numeric::str_to_f64(text)
else {
unreachable!();
};
Ok(Value::Numeric(crate::numeric::Numeric::Float(float)))
}
pub fn parse_signed_number(expr: &Expr) -> Result<Value> {
match expr {
Expr::Literal(Literal::Numeric(num)) => parse_numeric_literal(num),
Expr::Unary(op, expr) => match (op, expr.as_ref()) {
(UnaryOperator::Negative, Expr::Literal(Literal::Numeric(num))) => {
let data = "-".to_owned() + &num.to_string();
parse_numeric_literal(&data)
}
(UnaryOperator::Positive, Expr::Literal(Literal::Numeric(num))) => {
parse_numeric_literal(num)
}
_ => Err(LimboError::InvalidArgument(
"signed-number must follow the format: ([+|-] numeric-literal)".to_string(),
)),
},
_ => Err(LimboError::InvalidArgument(
"signed-number must follow the format: ([+|-] numeric-literal)".to_string(),
)),
}
}
pub fn parse_string(expr: &Expr) -> Result<String> {
match expr {
Expr::Name(name) if name.quoted_with('\'') => Ok(name.as_str().to_string()),
_ => Err(LimboError::InvalidArgument(format!(
"string parameter expected, got {expr:?} instead"
))),
}
}
#[allow(unused)]
pub fn parse_pragma_bool(expr: &Expr) -> Result<bool> {
const TRUE_VALUES: &[&str] = &["yes", "true", "on"];
const FALSE_VALUES: &[&str] = &["no", "false", "off"];
if let Ok(number) = parse_signed_number(expr) {
if let Value::Numeric(crate::numeric::Numeric::Integer(x @ (0 | 1))) = number {
return Ok(x != 0);
}
} else if let Expr::Name(name) = expr {
let ident = normalize_ident(name.as_str());
if TRUE_VALUES.contains(&ident.as_str()) {
return Ok(true);
}
if FALSE_VALUES.contains(&ident.as_str()) {
return Ok(false);
}
}
Err(LimboError::InvalidArgument(
"boolean pragma value must be either 0|1 integer or yes|true|on|no|false|off token"
.to_string(),
))
}
pub fn extract_column_name_from_expr(expr: impl AsRef<ast::Expr>) -> Option<String> {
match expr.as_ref() {
ast::Expr::Id(name) => Some(name.as_str().to_string()),
ast::Expr::DoublyQualified(_, _, name) | ast::Expr::Qualified(_, name) => {
Some(normalize_ident(name.as_str()))
}
_ => None,
}
}
#[derive(Debug, Clone)]
pub struct ViewTable {
pub name: String,
pub db_name: Option<String>,
pub alias: Option<String>,
}
#[derive(Debug, Clone)]
pub struct ViewColumn {
pub table_index: usize,
pub column: Column,
}
#[derive(Debug, Clone)]
pub struct ViewColumnSchema {
pub tables: Vec<ViewTable>,
pub columns: Vec<ViewColumn>,
}
impl ViewColumnSchema {
pub fn flat_columns(&self) -> Vec<Column> {
self.columns.iter().map(|vc| vc.column.clone()).collect()
}
pub fn table_columns(&self, table_index: usize) -> Vec<Column> {
self.columns
.iter()
.filter(|vc| vc.table_index == table_index)
.map(|vc| vc.column.clone())
.collect()
}
}
pub fn walk_select_expressions<F>(select: &ast::Select, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
walk_select_expressions_inner(select, func)
}
fn walk_select_expressions_inner<F>(select: &ast::Select, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
if let Some(with_clause) = &select.with {
for cte in &with_clause.ctes {
walk_select_expressions_inner(&cte.select, func)?;
}
}
walk_one_select_expressions(&select.body.select, func)?;
for compound in &select.body.compounds {
walk_one_select_expressions(&compound.select, func)?;
}
for sorted_col in &select.order_by {
walk_expr_with_subqueries(&sorted_col.expr, func)?;
}
if let Some(limit) = &select.limit {
walk_expr_with_subqueries(&limit.expr, func)?;
if let Some(offset) = &limit.offset {
walk_expr_with_subqueries(offset, func)?;
}
}
Ok(())
}
fn walk_one_select_expressions<F>(one_select: &ast::OneSelect, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
match one_select {
ast::OneSelect::Select {
columns,
from,
where_clause,
group_by,
window_clause,
..
} => {
for col in columns {
if let ast::ResultColumn::Expr(expr, _) = col {
walk_expr_with_subqueries(expr, func)?;
}
}
if let Some(from_clause) = from {
walk_from_clause_expressions(from_clause, func)?;
}
if let Some(where_expr) = where_clause {
walk_expr_with_subqueries(where_expr, func)?;
}
if let Some(group_by) = group_by {
for expr in &group_by.exprs {
walk_expr_with_subqueries(expr, func)?;
}
if let Some(having_expr) = &group_by.having {
walk_expr_with_subqueries(having_expr, func)?;
}
}
for window_def in window_clause {
walk_window_expressions(&window_def.window, func)?;
}
}
ast::OneSelect::Values(values) => {
for row in values {
for expr in row {
walk_expr_with_subqueries(expr, func)?;
}
}
}
}
Ok(())
}
fn walk_from_clause_expressions<F>(from_clause: &ast::FromClause, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
walk_select_table_expressions(&from_clause.select, func)?;
for join in &from_clause.joins {
walk_select_table_expressions(&join.table, func)?;
if let Some(ast::JoinConstraint::On(expr)) = &join.constraint {
walk_expr_with_subqueries(expr, func)?;
}
}
Ok(())
}
fn walk_select_table_expressions<F>(select_table: &ast::SelectTable, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
match select_table {
ast::SelectTable::Select(select, _) => walk_select_expressions_inner(select, func),
ast::SelectTable::Sub(from_clause, _) => walk_from_clause_expressions(from_clause, func),
ast::SelectTable::TableCall(_, args, _) => {
for arg in args {
walk_expr_with_subqueries(arg, func)?;
}
Ok(())
}
ast::SelectTable::Table(_, _, _) => Ok(()),
}
}
fn walk_window_expressions<F>(window: &ast::Window, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
for expr in &window.partition_by {
walk_expr_with_subqueries(expr, func)?;
}
for sorted_col in &window.order_by {
walk_expr_with_subqueries(&sorted_col.expr, func)?;
}
if let Some(frame_clause) = &window.frame_clause {
walk_frame_bound_expressions(&frame_clause.start, func)?;
if let Some(end_bound) = &frame_clause.end {
walk_frame_bound_expressions(end_bound, func)?;
}
}
Ok(())
}
fn walk_frame_bound_expressions<F>(bound: &ast::FrameBound, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
match bound {
ast::FrameBound::Following(expr) | ast::FrameBound::Preceding(expr) => {
walk_expr_with_subqueries(expr, func)
}
ast::FrameBound::CurrentRow
| ast::FrameBound::UnboundedFollowing
| ast::FrameBound::UnboundedPreceding => Ok(()),
}
}
pub fn walk_expr_with_subqueries<F>(expr: &ast::Expr, func: &mut F) -> Result<()>
where
F: FnMut(&ast::Expr) -> Result<WalkControl>,
{
walk_expr(expr, &mut |e| {
let control = func(e)?;
if matches!(control, WalkControl::Continue) {
match e {
ast::Expr::Subquery(select) | ast::Expr::Exists(select) => {
walk_select_expressions_inner(select, func)?;
}
ast::Expr::InSelect { rhs, .. } => {
walk_select_expressions_inner(rhs, func)?;
}
_ => {}
}
}
Ok(control)
})?;
Ok(())
}
fn validate_no_cross_db_references(
select_stmt: &ast::Select,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
if let Some(with_clause) = &select_stmt.with {
for cte in &with_clause.ctes {
validate_no_cross_db_references(&cte.select, view_db_name)?;
}
}
validate_one_select_no_cross_db(&select_stmt.body.select, view_db_name)?;
for compound in &select_stmt.body.compounds {
validate_one_select_no_cross_db(&compound.select, view_db_name)?;
}
Ok(())
}
fn validate_one_select_no_cross_db(
one_select: &ast::OneSelect,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
match one_select {
ast::OneSelect::Select { from, .. } => {
if let Some(from_clause) = from {
validate_from_clause_no_cross_db(from_clause, view_db_name)?;
}
}
ast::OneSelect::Values(_) => {}
}
Ok(())
}
fn validate_from_clause_no_cross_db(
from_clause: &ast::FromClause,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
validate_select_table_no_cross_db(&from_clause.select, view_db_name)?;
for join in &from_clause.joins {
validate_select_table_no_cross_db(&join.table, view_db_name)?;
}
Ok(())
}
fn reject_cross_db_qualified_name(
qualified_name: &ast::QualifiedName,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
if let Some(table_db_name) = &qualified_name.db_name {
let is_cross_db = match view_db_name {
Some(view_db) => !view_db
.as_str()
.eq_ignore_ascii_case(table_db_name.as_str()),
None => !table_db_name.as_str().eq_ignore_ascii_case("main"),
};
if is_cross_db {
return Err(crate::LimboError::ParseError(format!(
"view cannot reference table in attached database: {qualified_name}"
)));
}
}
Ok(())
}
fn validate_select_table_no_cross_db(
select_table: &ast::SelectTable,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
match select_table {
ast::SelectTable::Table(name, _, _) | ast::SelectTable::TableCall(name, _, _) => {
reject_cross_db_qualified_name(name, view_db_name)?;
}
ast::SelectTable::Select(select, _) => {
validate_no_cross_db_references(select, view_db_name)?;
}
ast::SelectTable::Sub(from_clause, _) => {
validate_from_clause_no_cross_db(from_clause, view_db_name)?;
}
}
Ok(())
}
pub fn validate_select_for_unsupported_features(select_stmt: &ast::Select) -> Result<()> {
walk_select_expressions(select_stmt, &mut |expr| {
if let ast::Expr::FunctionCall { order_by, .. } = expr {
if !order_by.is_empty() {
crate::bail_parse_error!(
"ORDER BY clause is not supported yet in aggregate functions"
);
}
}
Ok(WalkControl::Continue)
})
}
pub fn validate_select_for_views(
select_stmt: &ast::Select,
view_db_name: Option<&ast::Name>,
) -> Result<()> {
validate_select_for_unsupported_features(select_stmt)?;
validate_no_cross_db_references(select_stmt, view_db_name)?;
walk_select_expressions(select_stmt, &mut |expr| {
match expr {
ast::Expr::Subquery(subquery_select) | ast::Expr::Exists(subquery_select) => {
validate_no_cross_db_references(subquery_select, view_db_name)?;
}
ast::Expr::InSelect { rhs, .. } => {
validate_no_cross_db_references(rhs, view_db_name)?;
}
_ => {}
}
Ok(WalkControl::Continue)
})?;
Ok(())
}
pub fn extract_view_columns(
select_stmt: &ast::Select,
schema: &Schema,
) -> Result<ViewColumnSchema> {
let mut tables = Vec::new();
let mut columns = Vec::new();
let mut column_name_counts: HashMap<String, usize> = HashMap::default();
if let ast::OneSelect::Select {
ref from,
columns: select_columns,
..
} = &select_stmt.body.select
{
if let Some(from) = from {
match from.select.as_ref() {
ast::SelectTable::Table(qualified_name, alias, _) => {
let table_name = normalize_ident(qualified_name.name.as_str());
let db_name = qualified_name
.db_name
.as_ref()
.map(|db| normalize_ident(db.as_str()));
tables.push(ViewTable {
name: table_name,
db_name,
alias: alias.as_ref().map(|a| normalize_ident(a.name().as_str())),
});
}
_ => {
}
}
for join in &from.joins {
match join.table.as_ref() {
ast::SelectTable::Table(qualified_name, alias, _) => {
let table_name = normalize_ident(qualified_name.name.as_str());
let db_name = qualified_name
.db_name
.as_ref()
.map(|db| normalize_ident(db.as_str()));
tables.push(ViewTable {
name: table_name,
db_name,
alias: alias.as_ref().map(|a| normalize_ident(a.name().as_str())),
});
}
_ => {
}
}
}
}
let find_table_index = |name: &str| -> Option<usize> {
tables.iter().position(|t| {
t.name.eq_ignore_ascii_case(name)
|| t.alias
.as_ref()
.is_some_and(|a| a.eq_ignore_ascii_case(name))
})
};
for result_col in select_columns.iter() {
match result_col {
ast::ResultColumn::Expr(expr, alias) => {
let table_index = match expr.as_ref() {
ast::Expr::Qualified(table_ref, _col_name) => {
find_table_index(table_ref.as_str())
}
ast::Expr::Id(_col_name) => {
if !tables.is_empty() {
Some(0)
} else {
None
}
}
_ => None, };
let col_name = alias
.as_ref()
.filter(|a| !matches!(a, ast::As::ImplicitColumnName(_)))
.map(|a| a.name().as_str().to_string())
.or_else(|| extract_column_name_from_expr(expr))
.unwrap_or_else(|| {
expr.to_string()
});
columns.push(ViewColumn {
table_index: table_index.unwrap_or(usize::MAX),
column: Column::new_default_text(Some(col_name), "TEXT".to_string(), None),
});
}
ast::ResultColumn::Star => {
for (table_idx, table) in tables.iter().enumerate() {
if let Some(table_obj) = schema.get_table(&table.name) {
for table_column in table_obj.columns() {
let col_name =
table_column.name.clone().unwrap_or_else(|| "?".to_string());
let final_name =
if let Some(count) = column_name_counts.get_mut(&col_name) {
*count += 1;
format!("{}:{}", col_name, *count - 1)
} else {
column_name_counts.insert(col_name.clone(), 1);
col_name.clone()
};
columns.push(ViewColumn {
table_index: table_idx,
column: Column::new(
Some(final_name),
table_column.ty_str.clone(),
None,
None,
table_column.ty(),
table_column.collation_opt(),
ColDef::default(),
),
});
}
}
}
if tables.is_empty() {
columns.push(ViewColumn {
table_index: usize::MAX,
column: Column::new_default_text(
Some("*".to_string()),
"TEXT".to_string(),
None,
),
});
}
}
ast::ResultColumn::TableStar(table_ref) => {
let table_name_str = normalize_ident(table_ref.as_str());
if let Some(table_idx) = find_table_index(&table_name_str) {
if let Some(table) = schema.get_table(&tables[table_idx].name) {
for table_column in table.columns() {
let col_name =
table_column.name.clone().unwrap_or_else(|| "?".to_string());
let final_name =
if let Some(count) = column_name_counts.get_mut(&col_name) {
*count += 1;
format!("{}:{}", col_name, *count - 1)
} else {
column_name_counts.insert(col_name.clone(), 1);
col_name.clone()
};
columns.push(ViewColumn {
table_index: table_idx,
column: Column::new(
Some(final_name),
table_column.ty_str.clone(),
None,
None,
table_column.ty(),
table_column.collation_opt(),
ColDef::default(),
),
});
}
} else {
columns.push(ViewColumn {
table_index: usize::MAX,
column: Column::new_default_text(
Some(format!("{table_name_str}.*")),
"TEXT".to_string(),
None,
),
});
}
}
}
}
}
}
Ok(ViewColumnSchema { tables, columns })
}
pub fn rewrite_fk_parent_cols_if_self_ref(
clause: &mut ast::ForeignKeyClause,
table: &str,
from: &str,
to: &str,
) {
if clause.tbl_name.as_str().eq_ignore_ascii_case(table) {
for c in &mut clause.columns {
if c.col_name.as_str().eq_ignore_ascii_case(from) {
c.col_name = ast::Name::exact(to.to_owned());
}
}
}
}
pub fn check_expr_references_column(expr: &ast::Expr, col_name_normalized: &str) -> bool {
let mut found = false;
let _ = walk_expr(expr, &mut |e| {
if found {
return Ok(WalkControl::SkipChildren);
}
match e {
ast::Expr::Id(name) | ast::Expr::Name(name) => {
if name.as_str().eq_ignore_ascii_case(col_name_normalized) {
found = true;
return Ok(WalkControl::SkipChildren);
}
}
ast::Expr::Qualified(_, col) | ast::Expr::DoublyQualified(_, _, col) => {
if col.as_str().eq_ignore_ascii_case(col_name_normalized) {
found = true;
return Ok(WalkControl::SkipChildren);
}
}
_ => {}
}
Ok(WalkControl::Continue)
});
found
}
pub fn rename_identifiers(expr: &mut ast::Expr, from: &str, to: &str) {
let _ = walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
match e {
ast::Expr::Id(ref name) | ast::Expr::Name(ref name)
if name.as_str().eq_ignore_ascii_case(from) =>
{
*e = ast::Expr::Id(ast::Name::exact(to.to_owned()));
}
ast::Expr::Qualified(ref tbl, ref col_name)
if col_name.as_str().eq_ignore_ascii_case(from) =>
{
let tbl = tbl.clone();
*e = ast::Expr::Qualified(tbl, ast::Name::exact(to.to_owned()));
}
_ => {}
}
Ok(WalkControl::Continue)
},
);
}
pub fn rename_identifiers_scoped(
expr: &mut ast::Expr,
target_table: &str,
trigger_table: &str,
from: &str,
to: &str,
) {
rename_identifiers_scoped_inner(expr, target_table, trigger_table, from, to, true, None);
}
pub fn rename_identifiers_scoped_when_clause(
expr: &mut ast::Expr,
target_table: &str,
trigger_table: &str,
from: &str,
to: &str,
) {
rename_identifiers_scoped_inner(expr, target_table, trigger_table, from, to, false, None);
}
fn rename_identifiers_scoped_inner(
expr: &mut ast::Expr,
target_table: &str,
trigger_table: &str,
from: &str,
to: &str,
rename_unqualified: bool,
target_qualifiers: Option<&[String]>,
) {
let is_renaming_trigger_table = target_table.eq_ignore_ascii_case(trigger_table);
let _ = walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
match e {
ast::Expr::Subquery(select) | ast::Expr::Exists(select) => {
let mut quals = target_qualifiers.unwrap_or(&[]).to_vec();
rewrite_select_column_refs_scoped(
select,
target_table,
trigger_table,
from,
to,
&mut quals,
);
}
ast::Expr::InSelect { rhs, .. } => {
let mut quals = target_qualifiers.unwrap_or(&[]).to_vec();
rewrite_select_column_refs_scoped(
rhs,
target_table,
trigger_table,
from,
to,
&mut quals,
);
}
ast::Expr::Id(ref name) | ast::Expr::Name(ref name)
if rename_unqualified && name.as_str().eq_ignore_ascii_case(from) =>
{
*e = ast::Expr::Id(ast::Name::exact(to.to_owned()));
}
ast::Expr::Qualified(ref tbl, ref col_name)
if col_name.as_str().eq_ignore_ascii_case(from) =>
{
let tbl_norm = normalize_ident(tbl.as_str());
let should_rename = if tbl_norm == "new" || tbl_norm == "old" {
is_renaming_trigger_table
} else {
target_qualifiers.is_some_and(|qualifiers| qualifiers.contains(&tbl_norm))
|| tbl_norm.eq_ignore_ascii_case(target_table)
};
if should_rename {
let tbl = tbl.clone();
*e = ast::Expr::Qualified(tbl, ast::Name::exact(to.to_owned()));
}
}
_ => {}
}
Ok(WalkControl::Continue)
},
);
}
mod rename_column_view {
use super::*;
#[derive(Debug, Clone)]
pub struct RewrittenView {
pub sql: String,
pub select_stmt: ast::Select,
pub columns: Vec<Column>,
}
pub fn rewrite_view_sql_for_column_rename(
view_sql: &str,
schema: &Schema,
target_table: &str,
target_db_name: &str,
old_column: &str,
new_column: &str,
) -> Result<Option<RewrittenView>> {
let mut visiting_views = HashSet::default();
rewrite_view_sql_for_column_rename_inner(
view_sql,
schema,
target_table,
target_db_name,
old_column,
new_column,
&mut visiting_views,
)
}
fn rewrite_view_sql_for_column_rename_inner(
view_sql: &str,
schema: &Schema,
target_table: &str,
target_db_name: &str,
old_column: &str,
new_column: &str,
visiting_views: &mut HashSet<String>,
) -> Result<Option<RewrittenView>> {
let mut parser = Parser::new(view_sql.as_bytes());
let cmd = parser
.next_cmd()
.map_err(|e| LimboError::ParseError(format!("failed to parse view SQL: {e}")))?;
let Some(ast::Cmd::Stmt(ast::Stmt::CreateView {
temporary,
if_not_exists,
view_name,
columns: view_columns,
mut select,
})) = cmd
else {
return Ok(None);
};
let current_view_name = normalize_ident(view_name.name.as_str());
if !visiting_views.insert(current_view_name.clone()) {
return Err(LimboError::ParseError(format!(
"view {current_view_name} is circularly defined"
)));
}
let rewrite_result = (|| -> Result<Option<RewrittenView>> {
let original_select = select.clone();
let original_columns =
view_columns_from_select(&original_select, schema, &view_columns)?;
let ctx =
ViewRewriteCtx::new(schema, target_table, target_db_name, old_column, new_column);
let sql_changed =
rewrite_view_select_for_column_rename(&mut select, &ctx, &[], visiting_views)?;
let view_column_schema = extract_view_columns(&select, schema)?;
let mut final_columns = apply_view_column_rename(view_column_schema, &ctx);
for (i, indexed_col) in view_columns.iter().enumerate() {
if let Some(col) = final_columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
}
let columns_changed = !columns_equivalent(&original_columns, &final_columns);
if !sql_changed && !columns_changed {
return Ok(None);
}
let new_sql = if sql_changed {
let new_stmt = ast::Stmt::CreateView {
temporary,
if_not_exists,
view_name,
columns: view_columns,
select: select.clone(),
};
new_stmt.to_string()
} else {
view_sql.to_string()
};
Ok(Some(RewrittenView {
sql: new_sql,
select_stmt: select,
columns: final_columns,
}))
})();
visiting_views.remove(¤t_view_name);
rewrite_result
}
fn apply_view_column_rename(
view_columns: ViewColumnSchema,
ctx: &ViewRewriteCtx,
) -> Vec<Column> {
let target_norm = ctx.target_table_norm.as_str();
let mut columns = view_columns.columns;
for view_column in &mut columns {
if view_column.table_index == usize::MAX {
continue;
}
let table = &view_columns.tables[view_column.table_index];
if table_name_matches_target(
&table.name,
table.db_name.as_deref(),
target_norm,
&ctx.target_db_norm,
) {
if let Some(ref mut name) = view_column.column.name {
if name.as_str().eq_ignore_ascii_case(ctx.old_column) {
*name = ctx.new_column.to_string();
}
}
}
}
columns.into_iter().map(|vc| vc.column).collect()
}
fn view_columns_from_select(
select: &ast::Select,
schema: &Schema,
explicit: &[ast::IndexedColumn],
) -> Result<Vec<Column>> {
let view_column_schema = extract_view_columns(select, schema)?;
let mut columns = view_column_schema.flat_columns();
for (i, indexed_col) in explicit.iter().enumerate() {
if let Some(col) = columns.get_mut(i) {
col.name = Some(indexed_col.col_name.to_string());
}
}
Ok(columns)
}
fn columns_equivalent(left: &[Column], right: &[Column]) -> bool {
if left.len() != right.len() {
return false;
}
left.iter().zip(right.iter()).all(|(l, r)| {
let l_name = l.name.as_deref().unwrap_or("");
let r_name = r.name.as_deref().unwrap_or("");
l_name.eq_ignore_ascii_case(r_name)
})
}
#[derive(Clone)]
struct ViewSourceInfo {
qualifiers: Vec<String>,
columns_before: HashSet<String>,
rename_map: HashMap<String, String>,
is_target_table: bool,
db_name: Option<String>,
}
impl ViewSourceInfo {
fn matches_qualifier(&self, qualifier: &str) -> bool {
self.qualifiers.iter().any(|q| q == qualifier)
}
}
fn alias_name(alias: &ast::As) -> &str {
alias.name().as_str()
}
#[derive(Clone)]
struct CteInfo {
columns_before: HashSet<String>,
rename_map: HashMap<String, String>,
}
struct ViewRewriteCtx<'a> {
schema: &'a Schema,
target_table: &'a str,
target_table_norm: String,
target_db_norm: String,
old_column: &'a str,
old_column_norm: String,
new_column: &'a str,
}
impl<'a> ViewRewriteCtx<'a> {
fn new(
schema: &'a Schema,
target_table: &'a str,
target_db_name: &'a str,
old_column: &'a str,
new_column: &'a str,
) -> Self {
Self {
schema,
target_table,
target_table_norm: normalize_ident(target_table),
target_db_norm: normalize_ident(target_db_name),
old_column,
old_column_norm: normalize_ident(old_column),
new_column,
}
}
}
fn rewrite_view_select_for_column_rename(
select: &mut ast::Select,
ctx: &ViewRewriteCtx,
outer_scopes: &[&[ViewSourceInfo]],
visiting_views: &mut HashSet<String>,
) -> Result<bool> {
let mut changed = false;
let mut ctes: HashMap<String, CteInfo> = HashMap::default();
if let Some(ref mut with_clause) = select.with {
for cte in &mut with_clause.ctes {
let mut before_cols = select_output_columns(&cte.select, ctx, false)?;
apply_explicit_column_names(&mut before_cols, &cte.columns);
let cte_changed = rewrite_view_select_for_column_rename(
&mut cte.select,
ctx,
&[],
visiting_views,
)?;
changed |= cte_changed;
let mut after_cols = select_output_columns(&cte.select, ctx, true)?;
apply_explicit_column_names(&mut after_cols, &cte.columns);
let rename_map = build_rename_map(&before_cols, &after_cols, &ctx.old_column_norm);
ctes.insert(
normalize_ident(cte.tbl_name.as_str()),
CteInfo {
columns_before: before_cols
.into_iter()
.map(|c| normalize_ident(&c))
.collect(),
rename_map,
},
);
}
}
let mut scope_sources = rewrite_one_select_for_column_rename(
&mut select.body.select,
ctx,
&ctes,
outer_scopes,
&mut changed,
visiting_views,
)?;
for compound in &mut select.body.compounds {
let compound_sources = rewrite_one_select_for_column_rename(
&mut compound.select,
ctx,
&ctes,
outer_scopes,
&mut changed,
visiting_views,
)?;
if scope_sources.is_none() {
scope_sources = compound_sources;
}
}
if let Some(ref sources) = scope_sources {
for sorted_col in &mut select.order_by {
rewrite_expr_in_scope(
&mut sorted_col.expr,
sources,
outer_scopes,
ctx,
&mut changed,
visiting_views,
)?;
}
if let Some(ref mut limit) = select.limit {
rewrite_expr_in_scope(
&mut limit.expr,
sources,
outer_scopes,
ctx,
&mut changed,
visiting_views,
)?;
if let Some(ref mut offset) = limit.offset {
rewrite_expr_in_scope(
offset,
sources,
outer_scopes,
ctx,
&mut changed,
visiting_views,
)?;
}
}
}
Ok(changed)
}
fn rewrite_one_select_for_column_rename(
one_select: &mut ast::OneSelect,
ctx: &ViewRewriteCtx,
ctes: &HashMap<String, CteInfo>,
outer_scopes: &[&[ViewSourceInfo]],
changed: &mut bool,
visiting_views: &mut HashSet<String>,
) -> Result<Option<Vec<ViewSourceInfo>>> {
match one_select {
ast::OneSelect::Select {
columns,
from,
where_clause,
group_by,
window_clause,
..
} => {
let sources = if let Some(ref mut from_clause) = from {
rewrite_from_clause_for_column_rename(
from_clause,
ctx,
ctes,
outer_scopes,
changed,
visiting_views,
)?
} else {
Vec::new()
};
for col in columns {
if let ast::ResultColumn::Expr(expr, _) = col {
rewrite_expr_in_scope(
expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
}
if let Some(ref mut where_expr) = where_clause {
rewrite_expr_in_scope(
where_expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
if let Some(ref mut group_by) = group_by {
for expr in &mut group_by.exprs {
rewrite_expr_in_scope(
expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
if let Some(ref mut having_expr) = group_by.having {
rewrite_expr_in_scope(
having_expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
}
for window_def in window_clause {
for expr in &mut window_def.window.partition_by {
rewrite_expr_in_scope(
expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
for sorted in &mut window_def.window.order_by {
rewrite_expr_in_scope(
&mut sorted.expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
}
Ok(Some(sources))
}
ast::OneSelect::Values(values) => {
for row in values {
for expr in row {
rewrite_expr_in_scope(
expr,
&[],
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
}
Ok(None)
}
}
}
fn rewrite_from_clause_for_column_rename(
from_clause: &mut ast::FromClause,
ctx: &ViewRewriteCtx,
ctes: &HashMap<String, CteInfo>,
outer_scopes: &[&[ViewSourceInfo]],
changed: &mut bool,
visiting_views: &mut HashSet<String>,
) -> Result<Vec<ViewSourceInfo>> {
let mut sources = Vec::new();
let first_source = rewrite_select_table_for_column_rename(
&mut from_clause.select,
&[],
ctx,
ctes,
outer_scopes,
changed,
visiting_views,
)?;
sources.push(first_source);
for join in &mut from_clause.joins {
let right_source = rewrite_select_table_for_column_rename(
&mut join.table,
&sources,
ctx,
ctes,
outer_scopes,
changed,
visiting_views,
)?;
sources.push(right_source);
let (right_source, left_sources) = sources
.split_last()
.expect("sources should include the right-hand side join source");
if let Some(ref mut constraint) = join.constraint {
match constraint {
ast::JoinConstraint::On(expr) => {
rewrite_expr_in_scope(
expr,
&sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
ast::JoinConstraint::Using(cols) => {
*changed |= rewrite_using_columns(
cols,
left_sources,
right_source,
&ctx.old_column_norm,
ctx.new_column,
);
}
}
}
}
Ok(sources)
}
fn rewrite_select_table_for_column_rename(
select_table: &mut ast::SelectTable,
visible_sources: &[ViewSourceInfo],
ctx: &ViewRewriteCtx,
ctes: &HashMap<String, CteInfo>,
outer_scopes: &[&[ViewSourceInfo]],
changed: &mut bool,
visiting_views: &mut HashSet<String>,
) -> Result<ViewSourceInfo> {
match select_table {
ast::SelectTable::Table(tbl_name, alias, _) => {
let table_name_norm = normalize_ident(tbl_name.name.as_str());
let table_db_norm = tbl_name
.db_name
.as_ref()
.map(|db| normalize_ident(db.as_str()));
let mut qualifiers = Vec::new();
qualifiers.push(table_name_norm.clone());
if let Some(ref alias) = alias {
qualifiers.push(normalize_ident(alias_name(alias)));
}
if table_db_norm.is_none() {
if let Some(cte) = ctes.get(&table_name_norm) {
return Ok(ViewSourceInfo {
qualifiers,
columns_before: cte.columns_before.clone(),
rename_map: cte.rename_map.clone(),
is_target_table: false,
db_name: None,
});
}
}
let is_local = table_db_norm
.as_deref()
.is_none_or(|db| db == ctx.target_db_norm);
if is_local {
if let Some(view) = ctx.schema.views.get(&table_name_norm) {
let columns_before = view
.columns
.iter()
.filter_map(|col| col.name.clone())
.map(|name| normalize_ident(&name))
.collect();
let mut rename_map = HashMap::default();
if let Some(rewritten) = rewrite_view_sql_for_column_rename_inner(
&view.sql,
ctx.schema,
ctx.target_table,
&ctx.target_db_norm,
ctx.old_column,
ctx.new_column,
visiting_views,
)? {
rename_map = build_rename_map_from_columns(
&view.columns,
&rewritten.columns,
&ctx.old_column_norm,
);
}
return Ok(ViewSourceInfo {
qualifiers,
columns_before,
rename_map,
is_target_table: false,
db_name: table_db_norm,
});
}
}
let is_target = table_name_matches_target(
&table_name_norm,
table_db_norm.as_deref(),
&ctx.target_table_norm,
&ctx.target_db_norm,
);
let columns_before = if is_local {
table_source_columns(ctx.schema, &table_name_norm)
.unwrap_or_default()
.into_iter()
.map(|c| normalize_ident(&c))
.collect()
} else {
HashSet::default()
};
Ok(ViewSourceInfo {
qualifiers,
columns_before,
rename_map: HashMap::default(),
is_target_table: is_target,
db_name: table_db_norm,
})
}
ast::SelectTable::Select(select, alias) => {
let before_cols = select_output_columns(select, ctx, false)?;
*changed |=
rewrite_view_select_for_column_rename(select, ctx, &[], visiting_views)?;
let after_cols = select_output_columns(select, ctx, true)?;
let rename_map = build_rename_map(&before_cols, &after_cols, &ctx.old_column_norm);
let qualifiers = alias
.as_ref()
.map(|alias| vec![normalize_ident(alias_name(alias))])
.unwrap_or_default();
Ok(ViewSourceInfo {
qualifiers,
columns_before: before_cols
.into_iter()
.map(|c| normalize_ident(&c))
.collect(),
rename_map,
is_target_table: false,
db_name: None,
})
}
ast::SelectTable::Sub(from_clause, alias) => {
let before_cols = from_clause_output_columns(from_clause, ctx, false)?;
let _ = rewrite_from_clause_for_column_rename(
from_clause,
ctx,
ctes,
outer_scopes,
changed,
visiting_views,
)?;
let after_cols = from_clause_output_columns(from_clause, ctx, true)?;
let rename_map = build_rename_map(&before_cols, &after_cols, &ctx.old_column_norm);
let qualifiers = alias
.as_ref()
.map(|alias| vec![normalize_ident(alias_name(alias))])
.unwrap_or_default();
Ok(ViewSourceInfo {
qualifiers,
columns_before: before_cols
.into_iter()
.map(|c| normalize_ident(&c))
.collect(),
rename_map,
is_target_table: false,
db_name: None,
})
}
ast::SelectTable::TableCall(_, args, alias) => {
for arg in args {
rewrite_expr_in_scope(
arg,
visible_sources,
outer_scopes,
ctx,
changed,
visiting_views,
)?;
}
let qualifiers = alias
.as_ref()
.map(|alias| vec![normalize_ident(alias_name(alias))])
.unwrap_or_default();
Ok(ViewSourceInfo {
qualifiers,
columns_before: HashSet::default(),
rename_map: HashMap::default(),
is_target_table: false,
db_name: None,
})
}
}
}
fn rewrite_expr_in_scope(
expr: &mut ast::Expr,
sources: &[ViewSourceInfo],
outer_scopes: &[&[ViewSourceInfo]],
ctx: &ViewRewriteCtx,
changed: &mut bool,
visiting_views: &mut HashSet<String>,
) -> Result<()> {
let mut outer_scopes_for_subqueries: Vec<&[ViewSourceInfo]> =
Vec::with_capacity(outer_scopes.len() + 1);
if !sources.is_empty() {
outer_scopes_for_subqueries.push(sources);
}
outer_scopes_for_subqueries.extend_from_slice(outer_scopes);
walk_expr_mut(expr, &mut |e: &mut ast::Expr| -> Result<WalkControl> {
if rewrite_expr_column_ref_view(
e,
sources,
outer_scopes,
&ctx.target_db_norm,
&ctx.old_column_norm,
ctx.new_column,
) {
*changed = true;
}
match e {
ast::Expr::Subquery(select) | ast::Expr::Exists(select) => {
if rewrite_view_select_for_column_rename(
select,
ctx,
outer_scopes_for_subqueries.as_slice(),
visiting_views,
)? {
*changed = true;
}
}
ast::Expr::InSelect { rhs, .. } => {
if rewrite_view_select_for_column_rename(
rhs,
ctx,
outer_scopes_for_subqueries.as_slice(),
visiting_views,
)? {
*changed = true;
}
}
_ => {}
}
Ok(WalkControl::Continue)
})?;
Ok(())
}
fn rewrite_expr_column_ref_view(
expr: &mut ast::Expr,
sources: &[ViewSourceInfo],
outer_scopes: &[&[ViewSourceInfo]],
target_db_norm: &str,
old_column_norm: &str,
new_column: &str,
) -> bool {
let apply_rename = |source: &ViewSourceInfo, set_name: &mut dyn FnMut(String)| {
if source.is_target_table {
set_name(new_column.to_string());
return true;
}
if let Some(mapped) = source.rename_map.get(old_column_norm) {
set_name(mapped.to_string());
return true;
}
false
};
match expr {
ast::Expr::Qualified(ns, col) => {
let ns_norm = normalize_ident(ns.as_str());
if !col.as_str().eq_ignore_ascii_case(old_column_norm) {
return false;
}
let (source, local_ambiguous) =
resolve_qualified(sources, &ns_norm, target_db_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*col = ast::Name::exact(name);
});
}
if local_ambiguous {
return false;
}
for scope in outer_scopes {
let (source, ambiguous) = resolve_qualified(scope, &ns_norm, target_db_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*col = ast::Name::exact(name);
});
}
if ambiguous {
return false;
}
}
}
ast::Expr::DoublyQualified(schema, ns, col) => {
let schema_norm = normalize_ident(schema.as_str());
if schema_norm != target_db_norm {
return false;
}
let ns_norm = normalize_ident(ns.as_str());
if !col.as_str().eq_ignore_ascii_case(old_column_norm) {
return false;
}
let (source, local_ambiguous) = resolve_qualified(sources, &ns_norm, &schema_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*col = ast::Name::exact(name);
});
}
if local_ambiguous {
return false;
}
for scope in outer_scopes {
let (source, ambiguous) = resolve_qualified(scope, &ns_norm, &schema_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*col = ast::Name::exact(name);
});
}
if ambiguous {
return false;
}
}
}
ast::Expr::Id(col) | ast::Expr::Name(col) => {
if !col.as_str().eq_ignore_ascii_case(old_column_norm) {
return false;
}
let col_norm = normalize_ident(col.as_str());
let (source, local_ambiguous) = resolve_unqualified(sources, &col_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*expr = ast::Expr::Id(ast::Name::exact(name));
});
}
if local_ambiguous {
return false;
}
for scope in outer_scopes {
let (source, ambiguous) = resolve_unqualified(scope, &col_norm);
if let Some(source) = source {
return apply_rename(source, &mut |name| {
*expr = ast::Expr::Id(ast::Name::exact(name));
});
}
if ambiguous {
return false;
}
}
}
_ => {}
}
false
}
fn resolve_unqualified<'a>(
candidates: &'a [ViewSourceInfo],
old_column_norm: &str,
) -> (Option<&'a ViewSourceInfo>, bool) {
let mut matches = candidates
.iter()
.filter(|s| s.columns_before.contains(old_column_norm));
let Some(first) = matches.next() else {
return (None, false);
};
if matches.next().is_some() {
return (None, true);
}
(Some(first), false)
}
fn resolve_qualified<'a>(
candidates: &'a [ViewSourceInfo],
qualifier: &str,
target_db_norm: &str,
) -> (Option<&'a ViewSourceInfo>, bool) {
let mut matches = candidates.iter().filter(|s| {
s.matches_qualifier(qualifier)
&& s.db_name.as_deref().is_none_or(|db| db == target_db_norm)
});
let Some(first) = matches.next() else {
return (None, false);
};
if matches.next().is_some() {
return (None, true);
}
(Some(first), false)
}
fn rewrite_using_columns(
cols: &mut [ast::Name],
left_sources: &[ViewSourceInfo],
right: &ViewSourceInfo,
old_column_norm: &str,
new_column: &str,
) -> bool {
let mut changed = false;
let left_map = left_sources
.iter()
.find_map(|source| source.rename_map.get(old_column_norm));
let left_has_target = left_sources.iter().any(|source| source.is_target_table);
let right_map = right.rename_map.get(old_column_norm);
let should_rename =
left_has_target || right.is_target_table || left_map.is_some() || right_map.is_some();
if !should_rename {
return false;
}
let replacement = left_map
.or(right_map)
.map(|s| s.as_str())
.unwrap_or(new_column);
for col in cols {
if col.as_str().eq_ignore_ascii_case(old_column_norm) {
*col = ast::Name::exact(replacement.to_string());
changed = true;
}
}
changed
}
fn select_output_columns(
select: &ast::Select,
ctx: &ViewRewriteCtx,
apply_rename: bool,
) -> Result<Vec<String>> {
let view_columns = extract_view_columns(select, ctx.schema)?;
let mut columns = view_columns.columns;
if apply_rename {
let target_norm = ctx.target_table_norm.as_str();
for view_column in &mut columns {
if view_column.table_index == usize::MAX {
continue;
}
let table = &view_columns.tables[view_column.table_index];
if table_name_matches_target(
&table.name,
table.db_name.as_deref(),
target_norm,
&ctx.target_db_norm,
) {
if let Some(ref mut name) = view_column.column.name {
if name.as_str().eq_ignore_ascii_case(ctx.old_column) {
*name = ctx.new_column.to_string();
}
}
}
}
}
Ok(columns
.into_iter()
.map(|vc| vc.column.name.unwrap_or_else(|| "?".to_string()))
.collect())
}
fn apply_explicit_column_names(columns: &mut [String], explicit: &[ast::IndexedColumn]) {
for (i, indexed_col) in explicit.iter().enumerate() {
if let Some(col) = columns.get_mut(i) {
*col = indexed_col.col_name.to_string();
}
}
}
fn from_clause_output_columns(
from_clause: &ast::FromClause,
ctx: &ViewRewriteCtx,
apply_rename: bool,
) -> Result<Vec<String>> {
let dummy_select = ast::Select {
with: None,
body: ast::SelectBody {
select: ast::OneSelect::Select {
distinctness: None,
columns: vec![ast::ResultColumn::Star],
from: Some(from_clause.clone()),
where_clause: None,
group_by: None,
window_clause: Vec::new(),
},
compounds: Vec::new(),
},
order_by: Vec::new(),
limit: None,
};
select_output_columns(&dummy_select, ctx, apply_rename)
}
fn build_rename_map(
before_cols: &[String],
after_cols: &[String],
old_column_norm: &str,
) -> HashMap<String, String> {
let mut map = HashMap::default();
for (before, after) in before_cols.iter().zip(after_cols.iter()) {
if before.as_str().eq_ignore_ascii_case(old_column_norm)
&& !after.as_str().eq_ignore_ascii_case(before.as_str())
{
map.insert(old_column_norm.to_string(), after.to_string());
}
}
map
}
fn build_rename_map_from_columns(
before_cols: &[Column],
after_cols: &[Column],
old_column_norm: &str,
) -> HashMap<String, String> {
if before_cols.len() != after_cols.len() {
return HashMap::default();
}
let mut map = HashMap::default();
for (before, after) in before_cols.iter().zip(after_cols.iter()) {
let Some(before_name) = before.name.as_ref() else {
continue;
};
let Some(after_name) = after.name.as_ref() else {
continue;
};
if before_name.as_str().eq_ignore_ascii_case(old_column_norm)
&& !after_name
.as_str()
.eq_ignore_ascii_case(before_name.as_str())
{
map.insert(old_column_norm.to_string(), after_name.to_string());
}
}
map
}
fn table_name_matches_target(
table_name: &str,
table_db: Option<&str>,
target_table_norm: &str,
target_db_norm: &str,
) -> bool {
if !table_name.eq_ignore_ascii_case(target_table_norm) {
return false;
}
match table_db {
None => true,
Some(db) => db.eq_ignore_ascii_case(target_db_norm),
}
}
fn table_source_columns(schema: &Schema, table_name: &str) -> Option<Vec<String>> {
if let Some(table) = schema.get_table(table_name) {
return Some(
table
.columns()
.iter()
.filter_map(|col| col.name.clone())
.collect(),
);
}
let table_norm = normalize_ident(table_name);
if let Some(view) = schema.views.get(&table_norm) {
return Some(
view.columns
.iter()
.filter_map(|col| col.name.clone())
.collect(),
);
}
None
}
}
pub use rename_column_view::{rewrite_view_sql_for_column_rename, RewrittenView};
pub fn rewrite_check_expr_table_refs(expr: &mut ast::Expr, from: &str, to: &str) {
let _ = walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
match e {
ast::Expr::Qualified(tbl, col) => {
if tbl.as_str().eq_ignore_ascii_case(from) {
let col = col.clone();
*e = ast::Expr::Qualified(ast::Name::exact(to.to_owned()), col);
}
}
ast::Expr::Exists(select) | ast::Expr::Subquery(select) => {
rewrite_select_table_refs(select, from, to);
}
ast::Expr::InSelect { rhs, .. } => {
rewrite_select_table_refs(rhs, from, to);
}
ast::Expr::InTable { rhs, .. } => {
if rhs.name.as_str().eq_ignore_ascii_case(from) {
rhs.name = ast::Name::exact(to.to_owned());
}
}
_ => {}
}
Ok(WalkControl::Continue)
},
);
}
pub fn rewrite_column_references_if_needed(
col: &mut ast::ColumnDefinition,
table: &str,
from: &str,
to: &str,
) -> Result<()> {
for cc in &mut col.constraints {
match &mut cc.constraint {
ast::ColumnConstraint::ForeignKey { clause, .. } => {
rewrite_fk_parent_cols_if_self_ref(clause, table, from, to);
}
ast::ColumnConstraint::Check(expr) => {
rename_identifiers(expr, from, to);
}
ast::ColumnConstraint::Generated { expr, .. } => {
rename_identifiers(expr, from, to);
}
_ => {}
}
}
Ok(())
}
pub fn rewrite_column_level_fk_parent_columns_if_needed(
col: &mut ast::ColumnDefinition,
table: &str,
from: &str,
to: &str,
) {
for cc in &mut col.constraints {
if let ast::ColumnConstraint::ForeignKey { clause, .. } = &mut cc.constraint {
rewrite_fk_parent_cols_if_self_ref(clause, table, from, to);
}
}
}
pub fn rewrite_fk_parent_table_if_needed(
clause: &mut ast::ForeignKeyClause,
old_tbl: &str,
new_tbl: &str,
) -> bool {
if clause.tbl_name.as_str().eq_ignore_ascii_case(old_tbl) {
clause.tbl_name = ast::Name::exact(new_tbl.to_owned());
return true;
}
false
}
pub fn rewrite_inline_col_fk_target_if_needed(
col: &mut ast::ColumnDefinition,
old_tbl: &str,
new_tbl: &str,
) -> bool {
let mut changed = false;
for cc in &mut col.constraints {
if let ast::NamedColumnConstraint {
constraint: ast::ColumnConstraint::ForeignKey { clause, .. },
..
} = cc
{
changed |= rewrite_fk_parent_table_if_needed(clause, old_tbl, new_tbl);
}
}
changed
}
pub fn rewrite_trigger_cmd_table_refs(cmd: &mut ast::TriggerCmd, old_tbl: &str, new_tbl: &str) {
match cmd {
ast::TriggerCmd::Update {
tbl_name,
sets,
from,
where_clause,
..
} => {
if tbl_name.as_str().eq_ignore_ascii_case(old_tbl) {
*tbl_name = ast::Name::exact(new_tbl.to_owned());
}
for set in sets {
rewrite_check_expr_table_refs(&mut set.expr, old_tbl, new_tbl);
}
if let Some(ref mut from) = from {
rewrite_from_clause_table_refs(from, old_tbl, new_tbl);
}
if let Some(ref mut wc) = where_clause {
rewrite_check_expr_table_refs(wc, old_tbl, new_tbl);
}
}
ast::TriggerCmd::Insert {
tbl_name,
select,
upsert,
..
} => {
if tbl_name.as_str().eq_ignore_ascii_case(old_tbl) {
*tbl_name = ast::Name::exact(new_tbl.to_owned());
}
rewrite_select_table_refs(select, old_tbl, new_tbl);
if let Some(ref mut upsert) = upsert {
rewrite_upsert_table_refs(upsert, old_tbl, new_tbl);
}
}
ast::TriggerCmd::Delete {
tbl_name,
where_clause,
} => {
if tbl_name.as_str().eq_ignore_ascii_case(old_tbl) {
*tbl_name = ast::Name::exact(new_tbl.to_owned());
}
if let Some(ref mut wc) = where_clause {
rewrite_check_expr_table_refs(wc, old_tbl, new_tbl);
}
}
ast::TriggerCmd::Select(select) => {
rewrite_select_table_refs(select, old_tbl, new_tbl);
}
}
}
pub(crate) fn output_column_aliases(one_select: &ast::OneSelect) -> Vec<String> {
let ast::OneSelect::Select { columns, .. } = one_select else {
return Vec::new();
};
columns
.iter()
.filter_map(|col| match col {
ast::ResultColumn::Expr(_, Some(alias)) => Some(normalize_ident(alias.name().as_str())),
ast::ResultColumn::Expr(expr, None) => match expr.as_ref() {
ast::Expr::Id(name) | ast::Expr::Name(name) => Some(normalize_ident(name.as_str())),
ast::Expr::Qualified(_, col) | ast::Expr::DoublyQualified(_, _, col) => {
Some(normalize_ident(col.as_str()))
}
_ => None,
},
_ => None,
})
.collect()
}
pub(crate) fn is_order_by_alias_ref(
body: &ast::SelectBody,
expr: &ast::Expr,
old_col: &str,
) -> bool {
let name = match expr {
ast::Expr::Id(n) | ast::Expr::Name(n) => n.as_str(),
_ => return false,
};
if !name.eq_ignore_ascii_case(old_col) {
return false;
}
one_select_has_explicit_alias(&body.select, name)
|| body
.compounds
.iter()
.any(|c| one_select_has_explicit_alias(&c.select, name))
}
fn one_select_has_explicit_alias(one_select: &ast::OneSelect, name: &str) -> bool {
let ast::OneSelect::Select { columns, .. } = one_select else {
return false;
};
columns.iter().any(|col| match col {
ast::ResultColumn::Expr(_, Some(alias)) if alias.is_explicit() => {
alias.name().as_str().eq_ignore_ascii_case(name)
}
_ => false,
})
}
fn rewrite_select_column_refs_scoped(
select: &mut ast::Select,
target_table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
target_qualifiers: &mut Vec<String>,
) {
if let Some(with_clause) = &mut select.with {
for cte in &mut with_clause.ctes {
rewrite_select_column_refs_scoped(
&mut cte.select,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
}
rewrite_one_select_column_refs_scoped(
&mut select.body.select,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
for compound in &mut select.body.compounds {
rewrite_one_select_column_refs_scoped(
&mut compound.select,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
let added = match &select.body.select {
ast::OneSelect::Select { from, .. } => {
let local = from_clause_target_qualifiers(from, target_table);
extend_qualifiers_scoped(target_qualifiers, &local)
}
_ => 0,
};
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
let body = &select.body;
for col in &mut select.order_by {
if is_order_by_alias_ref(body, &col.expr, old_col) {
continue;
}
rename_identifiers_scoped_inner(
&mut col.expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
if let Some(limit) = &mut select.limit {
rename_identifiers_scoped_inner(
&mut limit.expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
if let Some(offset) = &mut limit.offset {
rename_identifiers_scoped_inner(
offset,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
}
target_qualifiers.truncate(target_qualifiers.len() - added);
}
fn from_clause_target_qualifiers(
from: &Option<ast::FromClause>,
target_table: &str,
) -> Vec<String> {
match from {
Some(from_clause) => from_clause_target_qualifiers_inner(from_clause, target_table),
None => Vec::new(),
}
}
fn from_clause_target_qualifiers_inner(
from_clause: &ast::FromClause,
target_table: &str,
) -> Vec<String> {
let target_table = normalize_ident(target_table);
let mut qualifiers = Vec::new();
let mut seen = HashSet::default();
collect_target_qualifiers(
&from_clause.select,
&target_table,
&mut qualifiers,
&mut seen,
);
for join in &from_clause.joins {
collect_target_qualifiers(&join.table, &target_table, &mut qualifiers, &mut seen);
}
qualifiers
}
fn collect_target_qualifiers(
st: &ast::SelectTable,
target_table: &str,
qualifiers: &mut Vec<String>,
seen: &mut HashSet<String>,
) {
let ast::SelectTable::Table(name, alias, _) = st else {
return;
};
if !name.name.as_str().eq_ignore_ascii_case(target_table) {
return;
}
if seen.insert(target_table.to_string()) {
qualifiers.push(target_table.to_string());
}
if let Some(alias) = alias {
let alias_norm = normalize_ident(alias.name().as_str());
if seen.insert(alias_norm.clone()) {
qualifiers.push(alias_norm);
}
}
}
fn rewrite_one_select_column_refs_scoped(
one: &mut ast::OneSelect,
target_table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
target_qualifiers: &mut Vec<String>,
) {
match one {
ast::OneSelect::Select {
from,
where_clause,
columns,
group_by,
window_clause,
..
} => {
let local = from_clause_target_qualifiers(from, target_table);
let added = extend_qualifiers_scoped(target_qualifiers, &local);
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
if let Some(ref mut from) = from {
rewrite_from_clause_column_refs_scoped(
from,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
if let Some(ref mut wc) = where_clause {
rename_identifiers_scoped_inner(
wc,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
for col in columns {
if let ast::ResultColumn::Expr(ref mut expr, _) = col {
rename_result_identifiers_scoped(
expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
}
if let Some(ref mut gb) = group_by {
for expr in &mut gb.exprs {
rename_identifiers_scoped_inner(
expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
if let Some(ref mut having) = gb.having {
rename_identifiers_scoped_inner(
having,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
}
for window_def in window_clause {
rewrite_window_column_refs_scoped(
&mut window_def.window,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
target_qualifiers.truncate(target_qualifiers.len() - added);
}
ast::OneSelect::Values(rows) => {
for row in rows {
for expr in row {
rename_identifiers_scoped(expr, target_table, trigger_table, old_col, new_col);
}
}
}
}
}
fn rename_result_identifiers_scoped(
expr: &mut ast::Expr,
target_table: &str,
trigger_table: &str,
from: &str,
to: &str,
rename_unqualified: bool,
target_qualifiers: Option<&[String]>,
) {
let is_renaming_trigger_table = target_table.eq_ignore_ascii_case(trigger_table);
let _ = walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
match e {
ast::Expr::Exists(_) => return Ok(WalkControl::SkipChildren),
ast::Expr::Subquery(select) => {
let mut quals = target_qualifiers.unwrap_or(&[]).to_vec();
rewrite_select_column_refs_scoped(
select,
target_table,
trigger_table,
from,
to,
&mut quals,
);
}
ast::Expr::InSelect { rhs, .. } => {
let mut quals = target_qualifiers.unwrap_or(&[]).to_vec();
rewrite_select_column_refs_scoped(
rhs,
target_table,
trigger_table,
from,
to,
&mut quals,
);
}
ast::Expr::Id(ref name) | ast::Expr::Name(ref name)
if rename_unqualified && name.as_str().eq_ignore_ascii_case(from) =>
{
*e = ast::Expr::Id(ast::Name::exact(to.to_owned()));
}
ast::Expr::Qualified(ref tbl, ref col_name)
if col_name.as_str().eq_ignore_ascii_case(from) =>
{
let tbl_norm = normalize_ident(tbl.as_str());
let should_rename = if tbl_norm == "new" || tbl_norm == "old" {
is_renaming_trigger_table
} else {
target_qualifiers.is_some_and(|qualifiers| qualifiers.contains(&tbl_norm))
|| tbl_norm.eq_ignore_ascii_case(target_table)
};
if should_rename {
let tbl = tbl.clone();
*e = ast::Expr::Qualified(tbl, ast::Name::exact(to.to_owned()));
}
}
_ => {}
}
Ok(WalkControl::Continue)
},
);
}
fn rewrite_window_column_refs_scoped(
window: &mut ast::Window,
target_table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
visible_target_qualifiers: &[String],
) {
let rename_unqualified =
!visible_target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
for expr in &mut window.partition_by {
rename_identifiers_scoped_inner(
expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(visible_target_qualifiers),
);
}
for sorted_col in &mut window.order_by {
rename_identifiers_scoped_inner(
&mut sorted_col.expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(visible_target_qualifiers),
);
}
}
fn rewrite_from_clause_column_refs_scoped(
from: &mut ast::FromClause,
target_table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
target_qualifiers: &mut Vec<String>,
) {
let local = from_clause_target_qualifiers_inner(from, target_table);
let added = extend_qualifiers_scoped(target_qualifiers, &local);
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
rewrite_select_table_entry_column_refs_scoped(
&mut from.select,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
for join in &mut from.joins {
rewrite_select_table_entry_column_refs_scoped(
&mut join.table,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
if let Some(ast::JoinConstraint::On(ref mut expr)) = join.constraint {
rename_identifiers_scoped_inner(
expr,
target_table,
trigger_table,
old_col,
new_col,
rename_unqualified,
Some(target_qualifiers),
);
}
}
target_qualifiers.truncate(target_qualifiers.len() - added);
}
fn rewrite_select_table_entry_column_refs_scoped(
st: &mut ast::SelectTable,
target_table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
target_qualifiers: &mut Vec<String>,
) {
match st {
ast::SelectTable::TableCall(_, ref mut args, _) => {
for arg in args {
rename_identifiers_scoped_inner(
arg,
target_table,
trigger_table,
old_col,
new_col,
true,
Some(target_qualifiers),
);
}
}
ast::SelectTable::Select(ref mut select, _) => {
rewrite_select_column_refs_scoped(
select,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
ast::SelectTable::Sub(ref mut from, _) => {
rewrite_from_clause_column_refs_scoped(
from,
target_table,
trigger_table,
old_col,
new_col,
target_qualifiers,
);
}
ast::SelectTable::Table(..) => {}
}
}
fn extend_qualifiers_scoped(qualifiers: &mut Vec<String>, local: &[String]) -> usize {
let before = qualifiers.len();
for qualifier in local {
if !qualifiers.iter().any(|q| q == qualifier) {
qualifiers.push(qualifier.clone());
}
}
qualifiers.len() - before
}
fn expr_still_references_renamed_column(
expr: &ast::Expr,
target_table: &str,
trigger_table: &str,
old_col: &str,
rename_unqualified: bool,
visible_target_qualifiers: &[String],
) -> bool {
let mut found = false;
let _ = walk_expr(expr, &mut |e: &ast::Expr| -> crate::Result<WalkControl> {
if found {
return Ok(WalkControl::Continue);
}
match e {
ast::Expr::Subquery(select) | ast::Expr::Exists(select) => {
let mut quals = visible_target_qualifiers.to_vec();
found = select_still_references_renamed_column(
select,
target_table,
trigger_table,
old_col,
&mut quals,
);
}
ast::Expr::InSelect { rhs, .. } => {
let mut quals = visible_target_qualifiers.to_vec();
found = select_still_references_renamed_column(
rhs,
target_table,
trigger_table,
old_col,
&mut quals,
);
}
ast::Expr::Qualified(ns, col) | ast::Expr::DoublyQualified(_, ns, col) => {
if col.as_str().eq_ignore_ascii_case(old_col) {
let ns_norm = normalize_ident(ns.as_str());
if ((ns_norm == "new" || ns_norm == "old")
&& target_table.eq_ignore_ascii_case(trigger_table))
|| visible_target_qualifiers.contains(&ns_norm)
|| (target_table.eq_ignore_ascii_case(trigger_table)
&& ns_norm.eq_ignore_ascii_case(trigger_table))
{
found = true;
}
}
}
ast::Expr::Id(name) | ast::Expr::Name(name) => {
if rename_unqualified && name.as_str().eq_ignore_ascii_case(old_col) {
found = true;
}
}
_ => {}
}
Ok(WalkControl::Continue)
});
found
}
fn one_select_still_references_renamed_column(
one: &ast::OneSelect,
target_table: &str,
trigger_table: &str,
old_col: &str,
target_qualifiers: &mut Vec<String>,
) -> bool {
match one {
ast::OneSelect::Select {
from,
where_clause,
columns,
group_by,
..
} => {
let local = from_clause_target_qualifiers(from, target_table);
let added = extend_qualifiers_scoped(target_qualifiers, &local);
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
let mut found = false;
if let Some(from_clause) = from {
if from_clause_still_references_renamed_column(
from_clause,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
found = true;
}
}
if !found {
if let Some(where_expr) = where_clause {
if expr_still_references_renamed_column(
where_expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
}
}
}
if !found {
for col in columns {
if let ast::ResultColumn::Expr(expr, _) = col {
if expr_still_references_renamed_column(
expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
break;
}
}
}
}
if !found {
if let Some(group_by) = group_by {
for expr in &group_by.exprs {
if expr_still_references_renamed_column(
expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
break;
}
}
if !found {
if let Some(having) = &group_by.having {
if expr_still_references_renamed_column(
having,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
}
}
}
}
}
target_qualifiers.truncate(target_qualifiers.len() - added);
found
}
ast::OneSelect::Values(rows) => {
for row in rows {
for expr in row {
if expr_still_references_renamed_column(
expr,
target_table,
trigger_table,
old_col,
true,
target_qualifiers,
) {
return true;
}
}
}
false
}
}
}
fn select_still_references_renamed_column(
select: &ast::Select,
target_table: &str,
trigger_table: &str,
old_col: &str,
target_qualifiers: &mut Vec<String>,
) -> bool {
if let Some(with_clause) = &select.with {
for cte in &with_clause.ctes {
if select_still_references_renamed_column(
&cte.select,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
return true;
}
}
}
if one_select_still_references_renamed_column(
&select.body.select,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
return true;
}
for compound in &select.body.compounds {
if one_select_still_references_renamed_column(
&compound.select,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
return true;
}
}
let added = match &select.body.select {
ast::OneSelect::Select { from, .. } => {
let local = from_clause_target_qualifiers(from, target_table);
extend_qualifiers_scoped(target_qualifiers, &local)
}
_ => 0,
};
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
let body = &select.body;
let mut found = false;
for sorted_col in &select.order_by {
if is_order_by_alias_ref(body, &sorted_col.expr, old_col) {
continue;
}
if expr_still_references_renamed_column(
&sorted_col.expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
break;
}
}
if !found {
if let Some(limit) = &select.limit {
if expr_still_references_renamed_column(
&limit.expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
}
if !found {
if let Some(offset) = &limit.offset {
if expr_still_references_renamed_column(
offset,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
}
}
}
}
}
target_qualifiers.truncate(target_qualifiers.len() - added);
found
}
fn select_table_still_references_renamed_column(
st: &ast::SelectTable,
target_table: &str,
trigger_table: &str,
old_col: &str,
target_qualifiers: &mut Vec<String>,
) -> bool {
match st {
ast::SelectTable::TableCall(_, args, _) => args.iter().any(|arg| {
expr_still_references_renamed_column(
arg,
target_table,
trigger_table,
old_col,
true,
target_qualifiers,
)
}),
ast::SelectTable::Select(select, _) => select_still_references_renamed_column(
select,
target_table,
trigger_table,
old_col,
target_qualifiers,
),
ast::SelectTable::Sub(from, _) => from_clause_still_references_renamed_column(
from,
target_table,
trigger_table,
old_col,
target_qualifiers,
),
ast::SelectTable::Table(..) => false,
}
}
fn from_clause_still_references_renamed_column(
from: &ast::FromClause,
target_table: &str,
trigger_table: &str,
old_col: &str,
target_qualifiers: &mut Vec<String>,
) -> bool {
let local = from_clause_target_qualifiers_inner(from, target_table);
let added = extend_qualifiers_scoped(target_qualifiers, &local);
let rename_unqualified =
!target_qualifiers.is_empty() || target_table.eq_ignore_ascii_case(trigger_table);
let mut found = false;
if select_table_still_references_renamed_column(
&from.select,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
found = true;
}
if !found {
for join in &from.joins {
if select_table_still_references_renamed_column(
&join.table,
target_table,
trigger_table,
old_col,
target_qualifiers,
) {
found = true;
break;
}
if let Some(ast::JoinConstraint::On(expr)) = &join.constraint {
if expr_still_references_renamed_column(
expr,
target_table,
trigger_table,
old_col,
rename_unqualified,
target_qualifiers,
) {
found = true;
break;
}
}
}
}
target_qualifiers.truncate(target_qualifiers.len() - added);
found
}
pub fn trigger_still_references_renamed_column(
trigger: &crate::schema::Trigger,
target_table: &str,
old_col: &str,
) -> bool {
if trigger.table_name.eq_ignore_ascii_case(target_table) {
if let ast::TriggerEvent::UpdateOf(cols) = &trigger.event {
if cols
.iter()
.any(|col| col.as_str().eq_ignore_ascii_case(old_col))
{
return true;
}
}
}
if let Some(when_clause) = &trigger.when_clause {
if expr_still_references_renamed_column(
when_clause,
target_table,
&trigger.table_name,
old_col,
false,
&[],
) {
return true;
}
}
for cmd in &trigger.commands {
match cmd {
ast::TriggerCmd::Update {
tbl_name,
sets,
from,
where_clause,
..
} => {
let targets_renamed_table = tbl_name.as_str().eq_ignore_ascii_case(target_table);
let visible_target_qualifiers = from_clause_target_qualifiers(from, target_table);
if targets_renamed_table
&& sets.iter().any(|set| {
set.col_names
.iter()
.any(|col_name| col_name.as_str().eq_ignore_ascii_case(old_col))
})
{
return true;
}
for set in sets {
if expr_still_references_renamed_column(
&set.expr,
target_table,
&trigger.table_name,
old_col,
targets_renamed_table,
&visible_target_qualifiers,
) {
return true;
}
}
if let Some(where_clause) = where_clause {
if expr_still_references_renamed_column(
where_clause,
target_table,
&trigger.table_name,
old_col,
targets_renamed_table,
&visible_target_qualifiers,
) {
return true;
}
}
if let Some(from_clause) = from {
if from_clause_still_references_renamed_column(
from_clause,
target_table,
&trigger.table_name,
old_col,
&mut Vec::new(),
) {
return true;
}
}
}
ast::TriggerCmd::Insert {
tbl_name,
col_names,
select,
upsert,
..
} => {
if tbl_name.as_str().eq_ignore_ascii_case(target_table)
&& col_names
.iter()
.any(|col_name| col_name.as_str().eq_ignore_ascii_case(old_col))
{
return true;
}
if select_still_references_renamed_column(
select,
target_table,
&trigger.table_name,
old_col,
&mut Vec::new(),
) {
return true;
}
if let Some(upsert) = upsert {
if let Some(index) = &upsert.index {
for target in &index.targets {
if expr_still_references_renamed_column(
&target.expr,
target_table,
&trigger.table_name,
old_col,
tbl_name.as_str().eq_ignore_ascii_case(target_table),
&[],
) {
return true;
}
}
if let Some(where_clause) = &index.where_clause {
if expr_still_references_renamed_column(
where_clause,
target_table,
&trigger.table_name,
old_col,
tbl_name.as_str().eq_ignore_ascii_case(target_table),
&[],
) {
return true;
}
}
}
}
}
ast::TriggerCmd::Delete {
tbl_name,
where_clause,
} => {
if let Some(where_clause) = where_clause {
if expr_still_references_renamed_column(
where_clause,
target_table,
&trigger.table_name,
old_col,
tbl_name.as_str().eq_ignore_ascii_case(target_table),
&[],
) {
return true;
}
}
}
ast::TriggerCmd::Select(select) => {
if select_still_references_renamed_column(
select,
target_table,
&trigger.table_name,
old_col,
&mut Vec::new(),
) {
return true;
}
}
}
}
false
}
fn rename_excluded_column_refs(expr: &mut ast::Expr, old_col: &str, new_col: &str) {
let _ = walk_expr_mut(
expr,
&mut |e: &mut ast::Expr| -> crate::Result<WalkControl> {
if let ast::Expr::Qualified(ns, col) | ast::Expr::DoublyQualified(_, ns, col) = e {
if ns.as_str().eq_ignore_ascii_case("excluded")
&& col.as_str().eq_ignore_ascii_case(old_col)
{
*col = ast::Name::exact(new_col.to_owned());
}
}
Ok(WalkControl::Continue)
},
);
}
fn rewrite_upsert_column_refs_scoped(
upsert: &mut ast::Upsert,
table: &str,
trigger_table: &str,
insert_table: &str,
old_col: &str,
new_col: &str,
) {
let insert_targets_renamed_table = insert_table.eq_ignore_ascii_case(table);
let rewrite_expr = |expr: &mut ast::Expr| {
if insert_targets_renamed_table {
rename_identifiers_scoped(expr, table, trigger_table, old_col, new_col);
rename_excluded_column_refs(expr, old_col, new_col);
} else {
rename_identifiers_scoped_when_clause(expr, table, trigger_table, old_col, new_col);
}
};
if let Some(ref mut index) = upsert.index {
for target in &mut index.targets {
rewrite_expr(&mut target.expr);
}
if let Some(ref mut wc) = index.where_clause {
rewrite_expr(wc);
}
}
if let ast::UpsertDo::Set {
ref mut sets,
ref mut where_clause,
} = upsert.do_clause
{
for set in sets {
if insert_targets_renamed_table {
for col_name in &mut set.col_names {
if col_name.as_str().eq_ignore_ascii_case(old_col) {
*col_name = ast::Name::exact(new_col.to_owned());
}
}
}
rewrite_expr(&mut set.expr);
}
if let Some(ref mut wc) = where_clause {
rewrite_expr(wc);
}
}
if let Some(ref mut next) = upsert.next {
rewrite_upsert_column_refs_scoped(
next,
table,
trigger_table,
insert_table,
old_col,
new_col,
);
}
}
pub fn rewrite_trigger_cmd_column_refs(
cmd: &mut ast::TriggerCmd,
table: &str,
trigger_table: &str,
old_col: &str,
new_col: &str,
) {
match cmd {
ast::TriggerCmd::Update {
tbl_name,
sets,
from,
where_clause,
..
} => {
let targets_renamed_table = tbl_name.as_str().eq_ignore_ascii_case(table);
if targets_renamed_table {
for set in sets {
for col_name in &mut set.col_names {
if col_name.as_str().eq_ignore_ascii_case(old_col) {
*col_name = ast::Name::exact(new_col.to_owned());
}
}
rename_identifiers_scoped(
&mut set.expr,
table,
trigger_table,
old_col,
new_col,
);
}
if let Some(ref mut wc) = where_clause {
rename_identifiers_scoped(wc, table, trigger_table, old_col, new_col);
}
} else {
for set in sets {
rename_identifiers_scoped_when_clause(
&mut set.expr,
table,
trigger_table,
old_col,
new_col,
);
}
if let Some(ref mut wc) = where_clause {
rename_identifiers_scoped_when_clause(
wc,
table,
trigger_table,
old_col,
new_col,
);
}
}
if let Some(ref mut from) = from {
rewrite_from_clause_column_refs_scoped(
from,
table,
trigger_table,
old_col,
new_col,
&mut Vec::new(),
);
}
}
ast::TriggerCmd::Insert {
tbl_name,
col_names,
select,
upsert,
..
} => {
let targets_renamed_table = tbl_name.as_str().eq_ignore_ascii_case(table);
if targets_renamed_table {
for col_name in col_names {
if col_name.as_str().eq_ignore_ascii_case(old_col) {
*col_name = ast::Name::exact(new_col.to_owned());
}
}
}
rewrite_select_column_refs_scoped(
select,
table,
trigger_table,
old_col,
new_col,
&mut Vec::new(),
);
if let Some(ref mut upsert) = upsert {
rewrite_upsert_column_refs_scoped(
upsert,
table,
trigger_table,
tbl_name.as_str(),
old_col,
new_col,
);
}
}
ast::TriggerCmd::Delete {
tbl_name,
where_clause,
} => {
let targets_renamed_table = tbl_name.as_str().eq_ignore_ascii_case(table);
if targets_renamed_table {
if let Some(ref mut wc) = where_clause {
rename_identifiers_scoped(wc, table, trigger_table, old_col, new_col);
}
} else if let Some(ref mut wc) = where_clause {
rename_identifiers_scoped_when_clause(wc, table, trigger_table, old_col, new_col);
}
}
ast::TriggerCmd::Select(select) => {
rewrite_select_column_refs_scoped(
select,
table,
trigger_table,
old_col,
new_col,
&mut Vec::new(),
);
}
}
}
fn rewrite_select_table_refs(select: &mut ast::Select, old_tbl: &str, new_tbl: &str) {
if let Some(with_clause) = &mut select.with {
for cte in &mut with_clause.ctes {
rewrite_select_table_refs(&mut cte.select, old_tbl, new_tbl);
}
}
rewrite_one_select_table_refs(&mut select.body.select, old_tbl, new_tbl);
for compound in &mut select.body.compounds {
rewrite_one_select_table_refs(&mut compound.select, old_tbl, new_tbl);
}
for col in &mut select.order_by {
rewrite_check_expr_table_refs(&mut col.expr, old_tbl, new_tbl);
}
}
fn rewrite_one_select_table_refs(one: &mut ast::OneSelect, old_tbl: &str, new_tbl: &str) {
match one {
ast::OneSelect::Select {
from,
where_clause,
columns,
group_by,
..
} => {
if let Some(ref mut from) = from {
rewrite_from_clause_table_refs(from, old_tbl, new_tbl);
}
if let Some(ref mut wc) = where_clause {
rewrite_check_expr_table_refs(wc, old_tbl, new_tbl);
}
for col in columns {
match col {
ast::ResultColumn::Expr(ref mut expr, _) => {
rewrite_check_expr_table_refs(expr, old_tbl, new_tbl);
}
ast::ResultColumn::TableStar(ref mut name) => {
if name.as_str().eq_ignore_ascii_case(old_tbl) {
*name = ast::Name::exact(new_tbl.to_owned());
}
}
ast::ResultColumn::Star => {}
}
}
if let Some(ref mut gb) = group_by {
for expr in &mut gb.exprs {
rewrite_check_expr_table_refs(expr, old_tbl, new_tbl);
}
if let Some(ref mut having) = gb.having {
rewrite_check_expr_table_refs(having, old_tbl, new_tbl);
}
}
}
ast::OneSelect::Values(rows) => {
for row in rows {
for expr in row {
rewrite_check_expr_table_refs(expr, old_tbl, new_tbl);
}
}
}
}
}
fn rewrite_from_clause_table_refs(from: &mut ast::FromClause, old_tbl: &str, new_tbl: &str) {
rewrite_select_table_entry_table_refs(&mut from.select, old_tbl, new_tbl);
for join in &mut from.joins {
rewrite_select_table_entry_table_refs(&mut join.table, old_tbl, new_tbl);
if let Some(ast::JoinConstraint::On(ref mut expr)) = join.constraint {
rewrite_check_expr_table_refs(expr, old_tbl, new_tbl);
}
}
}
fn rewrite_select_table_entry_table_refs(st: &mut ast::SelectTable, old_tbl: &str, new_tbl: &str) {
match st {
ast::SelectTable::Table(ref mut name, _, _) => {
if name.name.as_str().eq_ignore_ascii_case(old_tbl) {
name.name = ast::Name::exact(new_tbl.to_owned());
}
}
ast::SelectTable::TableCall(ref mut name, ref mut args, _) => {
if name.name.as_str().eq_ignore_ascii_case(old_tbl) {
name.name = ast::Name::exact(new_tbl.to_owned());
}
for arg in args {
rewrite_check_expr_table_refs(arg, old_tbl, new_tbl);
}
}
ast::SelectTable::Select(ref mut select, _) => {
rewrite_select_table_refs(select, old_tbl, new_tbl);
}
ast::SelectTable::Sub(ref mut from, _) => {
rewrite_from_clause_table_refs(from, old_tbl, new_tbl);
}
}
}
fn rewrite_upsert_table_refs(upsert: &mut ast::Upsert, old_tbl: &str, new_tbl: &str) {
if let Some(ref mut index) = upsert.index {
if let Some(ref mut wc) = index.where_clause {
rewrite_check_expr_table_refs(wc, old_tbl, new_tbl);
}
}
if let ast::UpsertDo::Set {
ref mut sets,
ref mut where_clause,
} = upsert.do_clause
{
for set in sets {
rewrite_check_expr_table_refs(&mut set.expr, old_tbl, new_tbl);
}
if let Some(ref mut wc) = where_clause {
rewrite_check_expr_table_refs(wc, old_tbl, new_tbl);
}
}
if let Some(ref mut next) = upsert.next {
rewrite_upsert_table_refs(next, old_tbl, new_tbl);
}
}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::schema::{BTreeTable, Type as SchemaValueType};
use turso_parser::ast::{self, Expr, FunctionTail, Literal, Name, Operator::*, Type, Variable};
use turso_parser::parser::Parser;
#[test]
fn test_normalize_ident() {
assert_eq!(normalize_ident("foo"), "foo");
assert_eq!(normalize_ident("FOO"), "foo");
assert_eq!(normalize_ident("ὈΔΥΣΣΕΎΣ"), "ὀδυσσεύς");
}
fn schema_with_tables(create_table_sqls: &[&str]) -> Schema {
let mut schema = Schema::new();
for (index, create_table_sql) in create_table_sqls.iter().enumerate() {
let root_page = i64::try_from(index).expect("test table index should fit in i64") + 2;
let table = BTreeTable::from_sql(create_table_sql, root_page)
.expect("test CREATE TABLE should parse");
schema
.add_btree_table(std::sync::Arc::new(table))
.expect("test table should be added to schema");
}
schema
}
fn schema_with_table(create_table_sql: &str) -> Schema {
schema_with_tables(&[create_table_sql])
}
fn parse_select_from(sql: &str) -> Option<ast::FromClause> {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser
.next_cmd()
.expect("test SQL should parse")
.expect("test SQL should contain a statement");
let ast::Cmd::Stmt(ast::Stmt::Select(select)) = cmd else {
panic!("expected SELECT statement");
};
match select.body.select {
ast::OneSelect::Select { from, .. } => from,
_ => panic!("expected simple SELECT"),
}
}
#[test]
fn test_rewrite_view_sql_select_table_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT s.x FROM (SELECT b AS x FROM t) AS s";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(rewritten.sql.contains("SELECT c AS x FROM t"));
}
#[test]
fn test_rewrite_view_sql_sub_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT s.b FROM (t) AS s";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(!rewritten.sql.contains("s.b"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_table_call_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql =
"CREATE VIEW v AS SELECT j.value FROM t JOIN json_each(json_array(t.b)) AS j";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(!rewritten.sql.contains("t.b"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_compound_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT b FROM t UNION ALL SELECT b FROM t ORDER BY b";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert_eq!(rewritten.sql.matches("SELECT c FROM t").count(), 2);
assert!(!rewritten.sql.contains("ORDER BY b"), "{}", rewritten.sql);
assert!(rewritten.sql.contains("ORDER BY c"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_cte_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS WITH cte AS (SELECT b FROM t) SELECT b FROM cte";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
rewritten.sql.contains("WITH cte AS (SELECT c FROM t)"),
"{}",
rewritten.sql
);
assert!(
rewritten.sql.contains("SELECT c FROM cte"),
"{}",
rewritten.sql
);
}
#[test]
fn test_rewrite_view_sql_cte_branch_with_explicit_columns() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS WITH cte(x) AS (SELECT b FROM t) SELECT x FROM cte";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
rewritten.sql.contains("WITH cte(x)") || rewritten.sql.contains("WITH cte (x)"),
"{}",
rewritten.sql
);
assert!(
rewritten.sql.contains("AS (SELECT c FROM t)"),
"{}",
rewritten.sql
);
assert!(
rewritten.sql.contains("SELECT x FROM cte"),
"{}",
rewritten.sql
);
}
#[test]
fn test_rewrite_trigger_cmd_table_refs_cte_branch() {
let sql = "CREATE TEMP TRIGGER trg AFTER INSERT ON temp.old BEGIN WITH cte AS (SELECT * FROM temp.old) SELECT * FROM cte; END";
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser
.next_cmd()
.expect("trigger SQL should parse")
.expect("trigger SQL should produce a statement");
let ast::Cmd::Stmt(ast::Stmt::CreateTrigger { commands, .. }) = cmd else {
panic!("expected CREATE TRIGGER statement");
};
let mut commands = commands;
let ast::TriggerCmd::Select(select) = &mut commands[0] else {
panic!("expected SELECT trigger command");
};
rewrite_select_table_refs(select, "old", "new");
let Some(with_clause) = &select.with else {
panic!("expected WITH clause");
};
let ast::OneSelect::Select {
from: Some(from), ..
} = &with_clause.ctes[0].select.body.select
else {
panic!("expected CTE SELECT core");
};
let ast::SelectTable::Table(tbl_name, _, _) = from.select.as_ref() else {
panic!("expected CTE SELECT FROM table");
};
assert_eq!(
tbl_name.db_name.as_ref().map(ast::Name::as_str),
Some("temp")
);
assert_eq!(tbl_name.name.as_str(), "new");
}
#[test]
fn test_from_clause_target_qualifiers_dedups_case_insensitively() {
let from = parse_select_from("SELECT 1 FROM Target AS tgt, target AS TARGET");
assert_eq!(
from_clause_target_qualifiers(&from, "target"),
vec!["target".to_string(), "tgt".to_string()]
);
}
#[test]
fn test_extend_qualifiers_scoped_preserves_first_seen_order() {
let mut qualifiers = vec!["target".to_string(), "outer_alias".to_string()];
let local = vec![
"outer_alias".to_string(),
"local_alias".to_string(),
"target".to_string(),
];
let added = extend_qualifiers_scoped(&mut qualifiers, &local);
assert_eq!(
qualifiers,
vec![
"target".to_string(),
"outer_alias".to_string(),
"local_alias".to_string(),
]
);
assert_eq!(added, 1);
qualifiers.truncate(qualifiers.len() - added);
assert_eq!(
qualifiers,
vec!["target".to_string(), "outer_alias".to_string(),]
);
}
#[test]
fn test_rewrite_trigger_cmd_table_refs_expr_subquery_branch() {
let sql =
"CREATE TEMP TRIGGER trg AFTER INSERT ON temp.old BEGIN SELECT EXISTS(SELECT 1 FROM temp.old); END";
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser
.next_cmd()
.expect("trigger SQL should parse")
.expect("trigger SQL should produce a statement");
let ast::Cmd::Stmt(ast::Stmt::CreateTrigger { commands, .. }) = cmd else {
panic!("expected CREATE TRIGGER statement");
};
let mut commands = commands;
let ast::TriggerCmd::Select(select) = &mut commands[0] else {
panic!("expected SELECT trigger command");
};
rewrite_select_table_refs(select, "old", "new");
let ast::OneSelect::Select { columns, .. } = &select.body.select else {
panic!("expected SELECT core");
};
let ast::ResultColumn::Expr(expr, _) = &columns[0] else {
panic!("expected expression result column");
};
let ast::Expr::Exists(subquery) = expr.as_ref() else {
panic!("expected EXISTS expression");
};
let ast::OneSelect::Select {
from: Some(from), ..
} = &subquery.body.select
else {
panic!("expected EXISTS subquery FROM clause");
};
let ast::SelectTable::Table(tbl_name, _, _) = from.select.as_ref() else {
panic!("expected EXISTS subquery FROM table");
};
assert_eq!(
tbl_name.db_name.as_ref().map(ast::Name::as_str),
Some("temp")
);
assert_eq!(tbl_name.name.as_str(), "new");
}
#[test]
fn test_rewrite_view_sql_join_on_branch() {
let schema = schema_with_tables(&["CREATE TABLE t (a, b)", "CREATE TABLE u (b)"]);
let view_sql = "CREATE VIEW v AS SELECT t.a FROM t JOIN u ON t.b = u.b";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(!rewritten.sql.contains("t.b"), "{}", rewritten.sql);
assert!(rewritten.sql.contains("t.c = u.b"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_join_using_branch() {
let schema = schema_with_tables(&["CREATE TABLE t (a, b)", "CREATE TABLE u (b)"]);
let view_sql = "CREATE VIEW v AS SELECT t.a FROM t JOIN u USING (b)";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
!rewritten.sql.contains("USING (b)") && !rewritten.sql.contains("USING(b)"),
"{}",
rewritten.sql
);
assert!(
rewritten.sql.contains("USING (c)") || rewritten.sql.contains("USING(c)"),
"{}",
rewritten.sql
);
}
#[test]
fn test_rewrite_view_sql_group_by_having_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT b FROM t GROUP BY b HAVING b > 0";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
rewritten
.sql
.contains("SELECT c FROM t GROUP BY c HAVING c > 0"),
"{}",
rewritten.sql
);
}
#[test]
fn test_rewrite_view_sql_window_clause_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT sum(a) OVER (PARTITION BY b ORDER BY b) FROM t";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
!rewritten.sql.contains("PARTITION BY b"),
"{}",
rewritten.sql
);
assert!(!rewritten.sql.contains("ORDER BY b"), "{}", rewritten.sql);
assert!(
rewritten.sql.contains("PARTITION BY c"),
"{}",
rewritten.sql
);
assert!(rewritten.sql.contains("ORDER BY c"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_limit_offset_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS SELECT a FROM t LIMIT b OFFSET b";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(!rewritten.sql.contains("LIMIT b"), "{}", rewritten.sql);
assert!(!rewritten.sql.contains("OFFSET b"), "{}", rewritten.sql);
assert!(rewritten.sql.contains("LIMIT c"), "{}", rewritten.sql);
assert!(rewritten.sql.contains("OFFSET c"), "{}", rewritten.sql);
}
#[test]
fn test_rewrite_view_sql_values_branch() {
let schema = schema_with_table("CREATE TABLE t (a, b)");
let view_sql = "CREATE VIEW v AS VALUES ((SELECT b FROM t LIMIT 1))";
let rewritten =
rewrite_view_sql_for_column_rename(view_sql, &schema, "t", "main", "b", "c")
.unwrap()
.expect("view should be rewritten");
assert!(
rewritten.sql.contains("VALUES ((SELECT c FROM t LIMIT 1))"),
"{}",
rewritten.sql
);
}
#[test]
fn test_indexed_variable_comparison() {
let expr1 = Expr::Variable(Variable::indexed(1u32.try_into().unwrap()));
let expr2 = Expr::Variable(Variable::indexed(1u32.try_into().unwrap()));
assert!(exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_named_variable_comparison() {
let expr1 = Expr::Variable(Variable::named(":a".to_string(), 1u32.try_into().unwrap()));
let expr2 = Expr::Variable(Variable::named(":a".to_string(), 1u32.try_into().unwrap()));
assert!(exprs_are_equivalent(&expr1, &expr2));
let expr1 = Expr::Variable(Variable::named(":a".to_string(), 1u32.try_into().unwrap()));
let expr2 = Expr::Variable(Variable::named(":b".to_string(), 2u32.try_into().unwrap()));
assert!(!exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_basic_addition_exprs_are_equivalent() {
let expr1 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("826".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("389".to_string()))),
);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("389".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("826".to_string()))),
);
assert!(exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_addition_expressions_equivalent_normalized() {
let expr1 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("123.0".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("243.0".to_string()))),
);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("243.0".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("123.0".to_string()))),
);
assert!(exprs_are_equivalent(&expr1, &expr2));
let expr3 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("123.0".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("243".to_string()))),
);
let expr4 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("243.0".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("123".to_string()))),
);
assert!(!exprs_are_equivalent(&expr3, &expr4));
}
#[test]
fn test_subtraction_expressions_not_equivalent() {
let expr3 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("364".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))),
);
let expr4 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("364".to_string()))),
);
assert!(!exprs_are_equivalent(&expr3, &expr4));
}
#[test]
fn test_subtraction_expressions_normalized() {
let expr3 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("66.0".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))),
);
let expr4 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("66.0".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))),
);
assert!(exprs_are_equivalent(&expr3, &expr4));
let expr5 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("66.0".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("22".to_string()))),
);
let expr6 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("66".to_string()))),
Subtract,
Box::new(Expr::Literal(Literal::Numeric("22.0".to_string()))),
);
assert!(!exprs_are_equivalent(&expr5, &expr6));
}
#[test]
fn test_expressions_equivalent_case_insensitive_functioncalls() {
let func1 = Expr::FunctionCall {
name: Name::exact("SUM".to_string()),
distinctness: None,
args: vec![Expr::Id(Name::exact("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let func2 = Expr::FunctionCall {
name: Name::exact("sum".to_string()),
distinctness: None,
args: vec![Expr::Id(Name::exact("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(exprs_are_equivalent(&func1, &func2));
let func3 = Expr::FunctionCall {
name: Name::exact("SUM".to_string()),
distinctness: Some(ast::Distinctness::Distinct),
args: vec![Expr::Id(Name::exact("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(!exprs_are_equivalent(&func1, &func3));
}
#[test]
fn test_expressions_equivalent_identical_fn_with_distinct() {
let sum = Expr::FunctionCall {
name: Name::exact("SUM".to_string()),
distinctness: None,
args: vec![Expr::Id(Name::exact("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
let sum_distinct = Expr::FunctionCall {
name: Name::exact("SUM".to_string()),
distinctness: Some(ast::Distinctness::Distinct),
args: vec![Expr::Id(Name::exact("x".to_string())).into()],
order_by: vec![],
filter_over: FunctionTail {
filter_clause: None,
over_clause: None,
},
};
assert!(!exprs_are_equivalent(&sum, &sum_distinct));
}
#[test]
fn test_expressions_equivalent_multiplication() {
let expr1 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("42.0".to_string()))),
Multiply,
Box::new(Expr::Literal(Literal::Numeric("38.0".to_string()))),
);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("38.0".to_string()))),
Multiply,
Box::new(Expr::Literal(Literal::Numeric("42.0".to_string()))),
);
assert!(exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_expressions_both_parenthesized_equivalent() {
let expr1 = Expr::Parenthesized(vec![Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("683".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("799".to_string()))),
)
.into()]);
let expr2 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("799".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("683".to_string()))),
);
assert!(exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_expressions_parenthesized_equivalent() {
let expr7 = Expr::Parenthesized(vec![Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
)
.into()]);
let expr8 = Expr::Binary(
Box::new(Expr::Literal(Literal::Numeric("6".to_string()))),
Add,
Box::new(Expr::Literal(Literal::Numeric("7".to_string()))),
);
assert!(exprs_are_equivalent(&expr7, &expr8));
}
#[test]
fn test_like_expressions_equivalent() {
let expr1 = Expr::Like {
lhs: Box::new(Expr::Id(Name::exact("name".to_string()))),
not: false,
op: ast::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
};
let expr2 = Expr::Like {
lhs: Box::new(Expr::Id(Name::exact("name".to_string()))),
not: false,
op: ast::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
};
assert!(exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_expressions_equivalent_like_escaped() {
let expr1 = Expr::Like {
lhs: Box::new(Expr::Id(Name::exact("name".to_string()))),
not: false,
op: ast::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("\\".to_string())))),
};
let expr2 = Expr::Like {
lhs: Box::new(Expr::Id(Name::exact("name".to_string()))),
not: false,
op: ast::LikeOperator::Like,
rhs: Box::new(Expr::Literal(Literal::String("%john%".to_string()))),
escape: Some(Box::new(Expr::Literal(Literal::String("#".to_string())))),
};
assert!(!exprs_are_equivalent(&expr1, &expr2));
}
#[test]
fn test_expressions_equivalent_between() {
let expr1 = Expr::Between {
lhs: Box::new(Expr::Id(Name::exact("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
let expr2 = Expr::Between {
lhs: Box::new(Expr::Id(Name::exact("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("18".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
assert!(exprs_are_equivalent(&expr1, &expr2));
let expr3 = Expr::Between {
lhs: Box::new(Expr::Id(Name::exact("age".to_string()))),
not: false,
start: Box::new(Expr::Literal(Literal::Numeric("20".to_string()))),
end: Box::new(Expr::Literal(Literal::Numeric("65".to_string()))),
};
assert!(!exprs_are_equivalent(&expr1, &expr3));
}
#[test]
fn test_cast_exprs_equivalent() {
let cast1 = Expr::Cast {
expr: Box::new(Expr::Literal(Literal::Numeric("123".to_string()))),
type_name: Some(Type {
name: "INTEGER".to_string(),
size: None,
array_dimensions: 0,
}),
};
let cast2 = Expr::Cast {
expr: Box::new(Expr::Literal(Literal::Numeric("123".to_string()))),
type_name: Some(Type {
name: "integer".to_string(),
size: None,
array_dimensions: 0,
}),
};
assert!(exprs_are_equivalent(&cast1, &cast2));
}
#[test]
fn test_ident_equivalency() {
assert!(check_ident_equivalency("\"foo\"", "foo"));
assert!(check_ident_equivalency("[foo]", "foo"));
assert!(check_ident_equivalency("`FOO`", "foo"));
assert!(check_ident_equivalency("\"foo\"", "`FOO`"));
assert!(!check_ident_equivalency("\"foo\"", "[bar]"));
assert!(!check_ident_equivalency("foo", "\"bar\""));
}
#[test]
fn test_simple_uri() {
let uri = "file:/home/user/db.sqlite";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.authority, None);
}
#[test]
fn test_uri_with_authority() {
let uri = "file://localhost/home/user/db.sqlite";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.authority, Some("localhost"));
}
#[test]
fn test_uri_with_invalid_authority() {
let uri = "file://example.com/home/user/db.sqlite";
let result = OpenOptions::parse(uri);
assert!(result.is_err());
}
#[test]
fn test_uri_with_query_params() {
let uri = "file:/home/user/db.sqlite?vfs=unix&mode=ro&immutable=1";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix".to_string()));
assert_eq!(opts.mode, OpenMode::ReadOnly);
assert!(opts.immutable);
}
#[test]
fn test_uri_with_fragment() {
let uri = "file:/home/user/db.sqlite#section1";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
}
#[test]
fn test_uri_with_percent_encoding() {
let uri = "file:/home/user/db%20with%20spaces.sqlite?vfs=unix";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db with spaces.sqlite");
assert_eq!(opts.vfs, Some("unix".to_string()));
}
#[test]
fn test_uri_without_scheme() {
let uri = "/home/user/db.sqlite";
let result = OpenOptions::parse(uri);
assert!(result.is_ok());
assert_eq!(result.unwrap().path, "/home/user/db.sqlite");
}
#[test]
fn test_uri_with_empty_query() {
let uri = "file:/home/user/db.sqlite?";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, None);
}
#[test]
fn test_uri_with_partial_query() {
let uri = "file:/home/user/db.sqlite?mode=rw";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.mode, OpenMode::ReadWrite);
assert_eq!(opts.vfs, None);
}
#[test]
fn test_uri_windows_style_path() {
let uri = "file:///C:/Users/test/db.sqlite";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/C:/Users/test/db.sqlite");
}
#[test]
fn test_uri_with_only_query_params() {
let uri = "file:?mode=memory&cache=shared";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "");
assert_eq!(opts.mode, OpenMode::Memory);
assert_eq!(opts.cache, CacheMode::Shared);
}
#[test]
fn test_uri_with_only_fragment() {
let uri = "file:#fragment";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "");
}
#[test]
fn test_uri_with_invalid_scheme() {
let uri = "http:/home/user/db.sqlite";
let result = OpenOptions::parse(uri);
assert!(result.is_ok());
assert_eq!(result.unwrap().path, "http:/home/user/db.sqlite");
}
#[test]
fn test_uri_with_multiple_query_params() {
let uri = "file:/home/user/db.sqlite?vfs=unix&mode=rw&cache=private&immutable=0";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix".to_string()));
assert_eq!(opts.mode, OpenMode::ReadWrite);
assert_eq!(opts.cache, CacheMode::Private);
assert!(!opts.immutable);
}
#[test]
fn test_uri_with_unknown_query_param() {
let uri = "file:/home/user/db.sqlite?unknown=param";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, None);
}
#[test]
fn test_uri_with_multiple_equal_signs() {
let uri = "file:/home/user/db.sqlite?vfs=unix=custom";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix=custom".to_string()));
}
#[test]
fn test_uri_with_trailing_slash() {
let uri = "file:/home/user/db.sqlite/";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite/");
}
#[test]
fn test_uri_with_encoded_characters_in_query() {
let uri = "file:/home/user/db.sqlite?vfs=unix%20mode";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/user/db.sqlite");
assert_eq!(opts.vfs, Some("unix mode".to_string()));
}
#[test]
fn test_uri_windows_network_path() {
let uri = "file://server/share/db.sqlite";
let result = OpenOptions::parse(uri);
assert!(result.is_err()); }
#[test]
fn test_uri_windows_drive_letter_with_slash() {
let uri = "file:///C:/database.sqlite";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/C:/database.sqlite");
}
#[test]
fn test_localhost_with_double_slash_and_no_path() {
let uri = "file://localhost";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "");
assert_eq!(opts.authority, Some("localhost"));
}
#[test]
fn test_uri_windows_drive_letter_without_slash() {
let uri = "file:///C:/database.sqlite";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/C:/database.sqlite");
}
#[test]
fn test_improper_mode() {
let uri = "file:data.db?mode=readonly";
let res = OpenOptions::parse(uri);
assert!(res.is_err());
let uri = "file:/home/user/db.sqlite?vfs=&mode=";
let res = OpenOptions::parse(uri);
assert!(res.is_err());
}
#[test]
fn test_simple_file_current_dir() {
let uri = "file:data.db";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "data.db");
assert_eq!(opts.authority, None);
assert_eq!(opts.vfs, None);
assert_eq!(opts.mode, OpenMode::ReadWriteCreate);
}
#[test]
fn test_simple_file_three_slash() {
let uri = "file:///home/data/data.db";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/data/data.db");
assert_eq!(opts.authority, None);
assert_eq!(opts.vfs, None);
assert_eq!(opts.mode, OpenMode::ReadWriteCreate);
}
#[test]
fn test_simple_file_two_slash_localhost() {
let uri = "file://localhost/home/fred/data.db";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/home/fred/data.db");
assert_eq!(opts.authority, Some("localhost"));
assert_eq!(opts.vfs, None);
}
#[test]
fn test_windows_double_invalid() {
let uri = "file://C:/home/fred/data.db?mode=ro";
let opts = OpenOptions::parse(uri);
assert!(opts.is_err());
}
#[test]
fn test_simple_file_two_slash() {
let uri = "file:///C:/Documents%20and%20Settings/fred/Desktop/data.db";
let opts = OpenOptions::parse(uri).unwrap();
assert_eq!(opts.path, "/C:/Documents and Settings/fred/Desktop/data.db");
assert_eq!(opts.vfs, None);
}
#[test]
fn test_decode_percent_basic() {
assert_eq!(decode_percent("hello%20world"), "hello world");
assert_eq!(decode_percent("file%3Adata.db"), "file:data.db");
assert_eq!(decode_percent("path%2Fto%2Ffile"), "path/to/file");
}
#[test]
fn test_decode_percent_edge_cases() {
assert_eq!(decode_percent(""), "");
assert_eq!(decode_percent("plain_text"), "plain_text");
assert_eq!(
decode_percent("%2Fhome%2Fuser%2Fdb.sqlite"),
"/home/user/db.sqlite"
);
assert_eq!(decode_percent("%41%42%43"), "ABC");
assert_eq!(decode_percent("%61%62%63"), "abc");
}
#[test]
fn test_decode_percent_invalid_sequences() {
assert_eq!(decode_percent("hello%"), "hello%");
assert_eq!(decode_percent("file%2"), "file%2");
assert_eq!(decode_percent("file%2X.db"), "file%2X.db");
assert_eq!(decode_percent("path%2Fto%2"), "path/to%2");
}
#[test]
fn test_decode_percent_mixed_valid_invalid() {
assert_eq!(decode_percent("hello%20world%"), "hello world%");
assert_eq!(decode_percent("%2Fpath%2Xto%2Ffile"), "/path%2Xto/file");
assert_eq!(decode_percent("file%3Adata.db%2"), "file:data.db%2");
}
#[test]
fn test_decode_percent_special_characters() {
assert_eq!(
decode_percent("%21%40%23%24%25%5E%26%2A%28%29"),
"!@#$%^&*()"
);
assert_eq!(decode_percent("%5B%5D%7B%7D%7C%5C%3A"), "[]{}|\\:");
}
#[test]
fn test_decode_percent_unmodified_valid_text() {
assert_eq!(
decode_percent("C:/Users/Example/Database.sqlite"),
"C:/Users/Example/Database.sqlite"
);
assert_eq!(
decode_percent("/home/user/db.sqlite"),
"/home/user/db.sqlite"
);
}
#[test]
fn test_text_to_integer() {
assert_eq!(
checked_cast_text_to_numeric("1", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1823400-00000", false).unwrap(),
Value::from_i64(1823400)
);
assert_eq!(
checked_cast_text_to_numeric("-10000000", false).unwrap(),
Value::from_i64(-10000000)
);
assert_eq!(
checked_cast_text_to_numeric("123xxx", false).unwrap(),
Value::from_i64(123)
);
assert_eq!(
checked_cast_text_to_numeric("9223372036854775807", false).unwrap(),
Value::from_i64(i64::MAX)
);
assert_eq!(
checked_cast_text_to_numeric("9223372036854775808", false).unwrap(),
Value::from_f64(9.22337203685478e18)
);
assert_eq!(
checked_cast_text_to_numeric("-9223372036854775808", false).unwrap(),
Value::from_i64(i64::MIN)
);
assert_eq!(
checked_cast_text_to_numeric("-9223372036854775809", false).unwrap(),
Value::from_f64(-9.22337203685478e18)
);
assert!(checked_cast_text_to_numeric("-", false).is_err());
}
#[test]
fn test_text_to_real() {
assert_eq!(
checked_cast_text_to_numeric("1", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1.0", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1.0", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1e10", false).unwrap(),
Value::from_i64(10_000_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("-1e10", false).unwrap(),
Value::from_i64(-10_000_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("1e-10", false).unwrap(),
Value::from_f64(1e-10)
);
assert_eq!(
checked_cast_text_to_numeric("-1e-10", false).unwrap(),
Value::from_f64(-1e-10)
);
assert_eq!(
checked_cast_text_to_numeric("1.123e10", false).unwrap(),
Value::from_i64(11_230_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("-1.123e10", false).unwrap(),
Value::from_i64(-11_230_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("1.123e-10", false).unwrap(),
Value::from_f64(1.123e-10)
);
assert_eq!(
checked_cast_text_to_numeric("-1.123-e-10", false).unwrap(),
Value::from_f64(-1.123)
);
assert_eq!(
checked_cast_text_to_numeric("1-282584294928", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("1.7976931348623157e309", false).unwrap(),
Value::from_f64(f64::INFINITY),
);
assert_eq!(
checked_cast_text_to_numeric("-1.7976931348623157e308", false).unwrap(),
Value::from_f64(f64::MIN),
);
assert_eq!(
checked_cast_text_to_numeric("-1.7976931348623157e309", false).unwrap(),
Value::from_f64(f64::NEG_INFINITY),
);
assert_eq!(
checked_cast_text_to_numeric("1E", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("1EE", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1E", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1.", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1.", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1.23E", false).unwrap(),
Value::from_f64(1.23)
);
assert_eq!(
checked_cast_text_to_numeric(".1.23E-", false).unwrap(),
Value::from_f64(0.1)
);
assert_eq!(
checked_cast_text_to_numeric("0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("-0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("-0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("-0.0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("0.0", false).unwrap(),
Value::from_i64(0)
);
assert!(checked_cast_text_to_numeric("-", false).is_err());
}
#[test]
fn test_text_to_numeric() {
assert_eq!(
checked_cast_text_to_numeric("1", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1823400-00000", false).unwrap(),
Value::from_i64(1823400)
);
assert_eq!(
checked_cast_text_to_numeric("-10000000", false).unwrap(),
Value::from_i64(-10000000)
);
assert_eq!(
checked_cast_text_to_numeric("123xxx", false).unwrap(),
Value::from_i64(123)
);
assert_eq!(
checked_cast_text_to_numeric("9223372036854775807", false).unwrap(),
Value::from_i64(i64::MAX)
);
assert_eq!(
checked_cast_text_to_numeric("9223372036854775808", false).unwrap(),
Value::from_f64(9.22337203685478e18)
); assert_eq!(
checked_cast_text_to_numeric("-9223372036854775808", false).unwrap(),
Value::from_i64(i64::MIN)
);
assert_eq!(
checked_cast_text_to_numeric("-9223372036854775809", false).unwrap(),
Value::from_f64(-9.22337203685478e18)
);
assert_eq!(
checked_cast_text_to_numeric("1.0", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1.0", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1e10", false).unwrap(),
Value::from_i64(10_000_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("-1e10", false).unwrap(),
Value::from_i64(-10_000_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("1e-10", false).unwrap(),
Value::from_f64(1e-10)
);
assert_eq!(
checked_cast_text_to_numeric("-1e-10", false).unwrap(),
Value::from_f64(-1e-10)
);
assert_eq!(
checked_cast_text_to_numeric("1.123e10", false).unwrap(),
Value::from_i64(11_230_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("-1.123e10", false).unwrap(),
Value::from_i64(-11_230_000_000)
);
assert_eq!(
checked_cast_text_to_numeric("1.123e-10", false).unwrap(),
Value::from_f64(1.123e-10)
);
assert_eq!(
checked_cast_text_to_numeric("-1.123-e-10", false).unwrap(),
Value::from_f64(-1.123)
);
assert_eq!(
checked_cast_text_to_numeric("1-282584294928", false).unwrap(),
Value::from_i64(1)
);
assert!(checked_cast_text_to_numeric("xxx", false).is_err());
assert_eq!(
checked_cast_text_to_numeric("1.7976931348623157e309", false).unwrap(),
Value::from_f64(f64::INFINITY)
);
assert_eq!(
checked_cast_text_to_numeric("-1.7976931348623157e308", false).unwrap(),
Value::from_f64(f64::MIN)
);
assert_eq!(
checked_cast_text_to_numeric("-1.7976931348623157e309", false).unwrap(),
Value::from_f64(f64::NEG_INFINITY)
);
assert_eq!(
checked_cast_text_to_numeric("1E", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("1EE", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1E", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1.", false).unwrap(),
Value::from_i64(1)
);
assert_eq!(
checked_cast_text_to_numeric("-1.", false).unwrap(),
Value::from_i64(-1)
);
assert_eq!(
checked_cast_text_to_numeric("1.23E", false).unwrap(),
Value::from_f64(1.23)
);
assert_eq!(
checked_cast_text_to_numeric("1.23E-", false).unwrap(),
Value::from_f64(1.23)
);
assert_eq!(
checked_cast_text_to_numeric("0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("-0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("-0.0", false).unwrap(),
Value::from_i64(0)
);
assert_eq!(
checked_cast_text_to_numeric("0.0", false).unwrap(),
Value::from_i64(0)
);
assert!(checked_cast_text_to_numeric("-", false).is_err());
assert_eq!(
checked_cast_text_to_numeric("-e", false).unwrap(),
Value::from_f64(0.0)
);
assert_eq!(
checked_cast_text_to_numeric("-E", false).unwrap(),
Value::from_f64(0.0)
);
}
#[test]
fn test_parse_numeric_str_valid_integer() {
assert_eq!(parse_numeric_str("123"), Ok((ValueType::Integer, "123")));
assert_eq!(parse_numeric_str("-456"), Ok((ValueType::Integer, "-456")));
assert_eq!(parse_numeric_str("+789"), Ok((ValueType::Integer, "+789")));
assert_eq!(
parse_numeric_str("000789"),
Ok((ValueType::Integer, "000789"))
);
}
#[test]
fn test_parse_numeric_str_valid_float() {
assert_eq!(
parse_numeric_str("123.456"),
Ok((ValueType::Float, "123.456"))
);
assert_eq!(
parse_numeric_str("-0.789"),
Ok((ValueType::Float, "-0.789"))
);
assert_eq!(
parse_numeric_str("+0.789"),
Ok((ValueType::Float, "+0.789"))
);
assert_eq!(parse_numeric_str("1e10"), Ok((ValueType::Float, "1e10")));
assert_eq!(parse_numeric_str("+1e10"), Ok((ValueType::Float, "+1e10")));
assert_eq!(
parse_numeric_str("-1.23e-4"),
Ok((ValueType::Float, "-1.23e-4"))
);
assert_eq!(
parse_numeric_str("1.23E+4"),
Ok((ValueType::Float, "1.23E+4"))
);
assert_eq!(parse_numeric_str("1.2.3"), Ok((ValueType::Float, "1.2")))
}
#[test]
fn test_parse_numeric_str_edge_cases() {
assert_eq!(parse_numeric_str("1e"), Ok((ValueType::Float, "1")));
assert_eq!(parse_numeric_str("1e-"), Ok((ValueType::Float, "1")));
assert_eq!(parse_numeric_str("1e+"), Ok((ValueType::Float, "1")));
assert_eq!(parse_numeric_str("-1e"), Ok((ValueType::Float, "-1")));
assert_eq!(parse_numeric_str("-1e-"), Ok((ValueType::Float, "-1")));
}
#[test]
fn test_parse_numeric_str_invalid() {
assert_eq!(parse_numeric_str(""), Err(()));
assert_eq!(parse_numeric_str("abc"), Err(()));
assert_eq!(parse_numeric_str("-"), Err(()));
assert_eq!(parse_numeric_str("+"), Err(()));
assert_eq!(parse_numeric_str("e10"), Err(()));
assert_eq!(parse_numeric_str(".e10"), Err(()));
}
#[test]
fn test_parse_numeric_str_with_whitespace() {
assert_eq!(parse_numeric_str(" 123"), Ok((ValueType::Integer, "123")));
assert_eq!(
parse_numeric_str(" -456.78 "),
Ok((ValueType::Float, "-456.78"))
);
assert_eq!(
parse_numeric_str(" 1.23e4 "),
Ok((ValueType::Float, "1.23e4"))
);
}
#[test]
fn test_parse_numeric_str_leading_zeros() {
assert_eq!(
parse_numeric_str("000123"),
Ok((ValueType::Integer, "000123"))
);
assert_eq!(
parse_numeric_str("000.456"),
Ok((ValueType::Float, "000.456"))
);
assert_eq!(
parse_numeric_str("0001e3"),
Ok((ValueType::Float, "0001e3"))
);
}
#[test]
fn test_parse_numeric_str_trailing_characters() {
assert_eq!(parse_numeric_str("123abc"), Ok((ValueType::Integer, "123")));
assert_eq!(
parse_numeric_str("456.78xyz"),
Ok((ValueType::Float, "456.78"))
);
assert_eq!(
parse_numeric_str("1.23e4extra"),
Ok((ValueType::Float, "1.23e4"))
);
}
#[test]
fn test_module_name_basic() {
let sql = "CREATE VIRTUAL TABLE x USING y;";
assert_eq!(module_name_from_sql(sql).unwrap(), "y");
}
#[test]
fn test_module_name_with_args() {
let sql = "CREATE VIRTUAL TABLE x USING modname('a', 'b');";
assert_eq!(module_name_from_sql(sql).unwrap(), "modname");
}
#[test]
fn test_module_name_missing_using() {
let sql = "CREATE VIRTUAL TABLE x (a, b);";
assert!(module_name_from_sql(sql).is_err());
}
#[test]
fn test_module_name_no_semicolon() {
let sql = "CREATE VIRTUAL TABLE x USING limbo(a, b)";
assert_eq!(module_name_from_sql(sql).unwrap(), "limbo");
}
#[test]
fn test_module_name_no_semicolon_or_args() {
let sql = "CREATE VIRTUAL TABLE x USING limbo";
assert_eq!(module_name_from_sql(sql).unwrap(), "limbo");
}
#[test]
fn test_module_args_none() {
let sql = "CREATE VIRTUAL TABLE x USING modname;";
let args = module_args_from_sql(sql).unwrap();
assert_eq!(args.len(), 0);
}
#[test]
fn test_module_args_basic() {
let sql = "CREATE VIRTUAL TABLE x USING modname('arg1', 'arg2');";
let args = module_args_from_sql(sql).unwrap();
assert_eq!(args.len(), 2);
assert_eq!("arg1", args[0].to_text().unwrap());
assert_eq!("arg2", args[1].to_text().unwrap());
for arg in args {
unsafe { arg.__free_internal_type() }
}
}
#[test]
fn test_module_args_with_escaped_quote() {
let sql = "CREATE VIRTUAL TABLE x USING modname('a''b', 'c');";
let args = module_args_from_sql(sql).unwrap();
assert_eq!(args.len(), 2);
assert_eq!(args[0].to_text().unwrap(), "a'b");
assert_eq!(args[1].to_text().unwrap(), "c");
for arg in args {
unsafe { arg.__free_internal_type() }
}
}
#[test]
fn test_module_args_unterminated_string() {
let sql = "CREATE VIRTUAL TABLE x USING modname('arg1, 'arg2');";
assert!(module_args_from_sql(sql).is_err());
}
#[test]
fn test_module_args_extra_garbage_after_quote() {
let sql = "CREATE VIRTUAL TABLE x USING modname('arg1'x);";
assert!(module_args_from_sql(sql).is_err());
}
#[test]
fn test_module_args_trailing_comma() {
let sql = "CREATE VIRTUAL TABLE x USING modname('arg1',);";
let args = module_args_from_sql(sql).unwrap();
assert_eq!(args.len(), 1);
assert_eq!("arg1", args[0].to_text().unwrap());
for arg in args {
unsafe { arg.__free_internal_type() }
}
}
#[test]
fn test_parse_numeric_literal_hex() {
assert_eq!(
parse_numeric_literal("0x1234").unwrap(),
Value::from_i64(4660)
);
assert_eq!(
parse_numeric_literal("0xFFFFFFFF").unwrap(),
Value::from_i64(4294967295)
);
assert_eq!(
parse_numeric_literal("0x7FFFFFFF").unwrap(),
Value::from_i64(2147483647)
);
assert_eq!(
parse_numeric_literal("0x7FFFFFFFFFFFFFFF").unwrap(),
Value::from_i64(9223372036854775807)
);
assert_eq!(
parse_numeric_literal("0xFFFFFFFFFFFFFFFF").unwrap(),
Value::from_i64(-1)
);
assert_eq!(
parse_numeric_literal("0x8000000000000000").unwrap(),
Value::from_i64(-9223372036854775808)
);
assert_eq!(
parse_numeric_literal("-0x1234").unwrap(),
Value::from_i64(-4660)
);
assert!(parse_numeric_literal("-0x8000000000000000").is_err());
}
#[test]
fn test_parse_numeric_literal_integer() {
assert_eq!(parse_numeric_literal("123").unwrap(), Value::from_i64(123));
assert_eq!(
parse_numeric_literal("9_223_372_036_854_775_807").unwrap(),
Value::from_i64(9223372036854775807)
);
}
#[test]
fn test_parse_numeric_literal_float() {
assert_eq!(
parse_numeric_literal("123.456").unwrap(),
Value::from_f64(123.456)
);
assert_eq!(
parse_numeric_literal(".123").unwrap(),
Value::from_f64(0.123)
);
assert_eq!(
parse_numeric_literal("1.23e10").unwrap(),
Value::from_f64(1.23e10)
);
assert_eq!(
parse_numeric_literal("1e-10").unwrap(),
Value::from_f64(1e-10)
);
assert_eq!(
parse_numeric_literal("1.23E+10").unwrap(),
Value::from_f64(1.23e10)
);
assert_eq!(
parse_numeric_literal("1.1_1").unwrap(),
Value::from_f64(1.11)
);
assert_eq!(
parse_numeric_literal("9223372036854775808").unwrap(),
Value::from_f64(9.223_372_036_854_776e18)
);
assert_eq!(
parse_numeric_literal("-9223372036854775809").unwrap(),
Value::from_f64(-9.223_372_036_854_776e18)
);
}
#[test]
fn test_parse_pragma_bool() {
assert!(parse_pragma_bool(&Expr::Literal(Literal::Numeric("1".into()))).unwrap(),);
assert!(parse_pragma_bool(&Expr::Name(Name::exact("true".into()))).unwrap(),);
assert!(parse_pragma_bool(&Expr::Name(Name::exact("on".into()))).unwrap(),);
assert!(parse_pragma_bool(&Expr::Name(Name::exact("yes".into()))).unwrap(),);
assert!(!parse_pragma_bool(&Expr::Literal(Literal::Numeric("0".into()))).unwrap(),);
assert!(!parse_pragma_bool(&Expr::Name(Name::exact("false".into()))).unwrap(),);
assert!(!parse_pragma_bool(&Expr::Name(Name::exact("off".into()))).unwrap(),);
assert!(!parse_pragma_bool(&Expr::Name(Name::exact("no".into()))).unwrap(),);
assert!(parse_pragma_bool(&Expr::Name(Name::exact("nono".into()))).is_err());
assert!(parse_pragma_bool(&Expr::Name(Name::exact("10".into()))).is_err());
assert!(parse_pragma_bool(&Expr::Name(Name::exact("-1".into()))).is_err());
}
#[test]
fn test_type_from_name() {
let tc = vec![
("", (SchemaValueType::Blob, false)),
("INTEGER", (SchemaValueType::Integer, true)),
("INT", (SchemaValueType::Integer, false)),
("CHAR", (SchemaValueType::Text, false)),
("CLOB", (SchemaValueType::Text, false)),
("TEXT", (SchemaValueType::Text, false)),
("BLOB", (SchemaValueType::Blob, false)),
("REAL", (SchemaValueType::Real, false)),
("FLOAT", (SchemaValueType::Real, false)),
("DOUBLE", (SchemaValueType::Real, false)),
("U128", (SchemaValueType::Numeric, false)),
];
for (input, expected) in tc {
let result = type_from_name(input);
assert_eq!(result, expected, "Failed for input: {input}");
}
}
#[test]
fn test_checked_cast_text_to_numeric_lossless_property() {
assert_eq!(checked_cast_text_to_numeric("1.xx", true), Err(()));
assert_eq!(checked_cast_text_to_numeric("abc", true), Err(()));
assert_eq!(checked_cast_text_to_numeric("--5", true), Err(()));
assert_eq!(checked_cast_text_to_numeric("12.34.56", true), Err(()));
assert_eq!(checked_cast_text_to_numeric("", true), Err(()));
assert_eq!(checked_cast_text_to_numeric(" ", true), Err(()));
assert_eq!(
checked_cast_text_to_numeric("0", true),
Ok(Value::from_i64(0))
);
assert_eq!(
checked_cast_text_to_numeric("42", true),
Ok(Value::from_i64(42))
);
assert_eq!(
checked_cast_text_to_numeric("-42", true),
Ok(Value::from_i64(-42))
);
assert_eq!(
checked_cast_text_to_numeric("999999999999", true),
Ok(Value::from_i64(999_999_999_999))
);
assert_eq!(
checked_cast_text_to_numeric("1.0", true),
Ok(Value::from_i64(1))
);
assert_eq!(
checked_cast_text_to_numeric("-3.22", true),
Ok(Value::from_f64(-3.22))
);
assert_eq!(
checked_cast_text_to_numeric("0.001", true),
Ok(Value::from_f64(0.001))
);
assert_eq!(
checked_cast_text_to_numeric("2e3", true),
Ok(Value::from_i64(2000))
);
assert_eq!(
checked_cast_text_to_numeric("-5.5e-2", true),
Ok(Value::from_f64(-0.055))
);
assert_eq!(
checked_cast_text_to_numeric(" 123 ", true),
Ok(Value::from_i64(123))
);
assert_eq!(
checked_cast_text_to_numeric("\t-3.22\n", true),
Ok(Value::from_f64(-3.22))
);
}
#[test]
fn test_trim_ascii_whitespace_helper() {
assert_eq!(trim_ascii_whitespace(" hello "), "hello");
assert_eq!(trim_ascii_whitespace("\t\nhello\r\n"), "hello");
assert_eq!(trim_ascii_whitespace("hello"), "hello");
assert_eq!(trim_ascii_whitespace(" "), "");
assert_eq!(trim_ascii_whitespace(""), "");
assert_eq!(
trim_ascii_whitespace("\u{00A0}hello\u{00A0}"),
"\u{00A0}hello\u{00A0}"
);
assert_eq!(
trim_ascii_whitespace(" \u{00A0}hello\u{00A0} "),
"\u{00A0}hello\u{00A0}"
);
}
#[test]
fn test_cast_real_to_integer_limits() {
let max_exact = ((1i64 << 51) - 1) as f64;
assert_eq!(cast_real_to_integer(max_exact), Ok((1i64 << 51) - 1));
assert_eq!(cast_real_to_integer(-max_exact), Ok(-((1i64 << 51) - 1)));
assert_eq!(cast_real_to_integer((1i64 << 51) as f64), Ok(1i64 << 51));
assert_eq!(cast_real_to_integer((1i64 << 52) as f64), Ok(1i64 << 52));
assert_eq!(cast_real_to_integer((1i64 << 62) as f64), Ok(1i64 << 62));
assert_eq!(
cast_real_to_integer(426601719749026560.0),
Ok(426601719749026560)
);
assert_eq!(cast_real_to_integer(i64::MIN as f64), Err(()));
assert_eq!(cast_real_to_integer(i64::MAX as f64), Err(()));
assert_eq!(cast_real_to_integer(9223372036854775808.0), Err(()));
assert_eq!(cast_real_to_integer(-9223372036854777856.0), Err(()));
assert_eq!(cast_real_to_integer(1.5), Err(()));
assert_eq!(cast_real_to_integer(-1.5), Err(()));
assert_eq!(cast_real_to_integer(f64::INFINITY), Err(()));
assert_eq!(cast_real_to_integer(f64::NEG_INFINITY), Err(()));
assert_eq!(cast_real_to_integer(f64::NAN), Err(()));
}
}