use datafusion_common::ScalarValue;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_expr::Expr;
use datafusion_sql::unparser::{self, dialect::Dialect};
/// SQL dialect used when unparsing DataFusion expressions for Lance.
///
/// Only identifier quoting is overridden; every other [`Dialect`] hook keeps the trait's
/// default behavior.
struct LanceSqlDialect;

impl Dialect for LanceSqlDialect {
    /// Decides whether `identifier` must be wrapped in backticks.
    ///
    /// An identifier is left bare only when it is "plain": every character is a lowercase
    /// ASCII letter, an underscore, or (anywhere but the first position) an ASCII digit.
    /// Anything else — an uppercase letter, a leading digit, whitespace, punctuation, or a
    /// non-ASCII character — is quoted with `` ` ``. An empty identifier counts as plain and
    /// is therefore left unquoted.
    fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
        let is_plain = identifier.chars().enumerate().all(|(position, ch)| {
            ch == '_' || ch.is_ascii_lowercase() || (position > 0 && ch.is_ascii_digit())
        });
        (!is_plain).then_some('`')
    }
}
/// Prefix of the temporary Utf8 placeholders that stand in for non-null binary literals
/// while an expression is unparsed. Each placeholder is `{prefix}{index}__`, where `index`
/// is the position of the literal's bytes in the collected bindings. The single-quoted
/// form of the placeholder is afterwards swapped for the hex literal of those bytes.
/// NOTE(review): a user string literal containing exactly such a placeholder would also be
/// substituted; the prefix is assumed never to appear in real data.
const BINARY_PLACEHOLDER_PREFIX: &str = "__lancedb_binary_placeholder_";
/// Renders `bytes` as a SQL hexadecimal binary literal, e.g. `[0xDE, 0xAD]` -> `X'DEAD'`.
///
/// Hex digits are uppercase and every byte yields exactly two digits. An empty slice yields
/// `X''`. The output is built in a single preallocated buffer instead of allocating a
/// temporary string per byte.
fn bytes_to_hex_sql(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    // "X'" + two digits per byte + closing quote.
    let mut sql = String::with_capacity(bytes.len() * 2 + 3);
    sql.push_str("X'");
    for &byte in bytes {
        sql.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        sql.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    sql.push('\'');
    sql
}
fn has_binary_literal(expr: &Expr) -> bool {
let mut found = false;
let _ = expr.apply(&mut |e: &Expr| {
if matches!(
e,
Expr::Literal(ScalarValue::Binary(_) | ScalarValue::LargeBinary(_), _)
) {
found = true;
Ok(TreeNodeRecursion::Stop)
} else {
Ok(TreeNodeRecursion::Continue)
}
});
found
}
fn run_unparser(expr: &Expr) -> crate::Result<String> {
let ast = unparser::Unparser::new(&LanceSqlDialect)
.expr_to_sql(expr)
.map_err(|e| crate::Error::InvalidInput {
message: format!("failed to serialize expression to SQL: {}", e),
})?;
Ok(ast.to_string())
}
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
if !has_binary_literal(expr) {
return run_unparser(expr);
}
let mut bindings: Vec<Vec<u8>> = Vec::new();
let rewritten = expr
.clone()
.transform(|e: Expr| match e {
Expr::Literal(ScalarValue::Binary(Some(bytes)), m)
| Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), m) => {
let placeholder = format!("{}{}__", BINARY_PLACEHOLDER_PREFIX, bindings.len());
bindings.push(bytes);
Ok(Transformed::yes(Expr::Literal(
ScalarValue::Utf8(Some(placeholder)),
m,
)))
}
Expr::Literal(ScalarValue::Binary(None), m)
| Expr::Literal(ScalarValue::LargeBinary(None), m) => {
Ok(Transformed::yes(Expr::Literal(ScalarValue::Null, m)))
}
other => Ok(Transformed::no(other)),
})
.map_err(|e| crate::Error::InvalidInput {
message: format!("failed to rewrite expression: {}", e),
})?
.data;
let mut sql = run_unparser(&rewritten)?;
for (i, bytes) in bindings.iter().enumerate() {
let quoted = format!("'{}{}__'", BINARY_PLACEHOLDER_PREFIX, i);
sql = sql.replace("ed, &bytes_to_hex_sql(bytes));
}
Ok(sql)
}