use crate::common::SmartString;
use crate::core::{DataType, Error, Result, Value};
use crate::functions::{
FunctionDataType, FunctionInfo, FunctionSignature, FunctionType, ScalarFunction,
};
use crate::validate_arg_count;
/// Scalar function implementing `CAST(value, type_name)`.
///
/// The type is a stateless unit struct; its behavior is provided by its
/// `ScalarFunction` implementation.
#[derive(Debug, Default)]
pub struct CastFunction;
/// `CAST(value, type_name)`: converts `value` to the type named by the string
/// `type_name` (matched case-insensitively).
impl ScalarFunction for CastFunction {
    /// Name of the function.
    fn name(&self) -> &str {
        "CAST"
    }

    /// Builds the function metadata: a scalar function taking an `Any` value and
    /// a `String` type name and returning `Any`.
    fn info(&self) -> FunctionInfo {
        // NOTE(review): the two trailing `2`s are presumably the minimum and
        // maximum argument counts -- confirm against `FunctionSignature::new`.
        let signature = FunctionSignature::new(
            FunctionDataType::Any,
            vec![FunctionDataType::Any, FunctionDataType::String],
            2,
            2,
        );
        FunctionInfo::new(
            "CAST",
            FunctionType::Scalar,
            "Converts a value from one data type to another",
            signature,
        )
    }

    /// Casts `args[0]` to the type named by `args[1]`.
    ///
    /// A null input yields a null tagged with the target type (or an untyped
    /// null when the type name is not recognised).
    ///
    /// # Errors
    ///
    /// Fails when the second argument is not text, when a non-null value is
    /// cast to an unrecognised type name, or when the specific conversion
    /// rejects the value.
    fn evaluate(&self, args: &[Value]) -> Result<Value> {
        validate_arg_count!(args, "CAST", 2);
        let (source, type_arg) = (&args[0], &args[1]);

        let type_name = if let Value::Text(raw) = type_arg {
            raw.to_uppercase()
        } else {
            return Err(Error::invalid_argument(
                "Second argument to CAST must be a string type name",
            ));
        };

        // Resolve every accepted alias exactly once; `None` means "unknown type".
        let resolved = match type_name.as_str() {
            "INT" | "INTEGER" => Some(DataType::Integer),
            "FLOAT" | "REAL" | "DOUBLE" => Some(DataType::Float),
            "STRING" | "TEXT" | "VARCHAR" | "CHAR" => Some(DataType::Text),
            "BOOLEAN" | "BOOL" => Some(DataType::Boolean),
            "TIMESTAMP" | "DATETIME" | "DATE" | "TIME" => Some(DataType::Timestamp),
            "JSON" => Some(DataType::Json),
            _ => None,
        };

        // Nulls never reach the converters: they only pick up the target type.
        if source.is_null() {
            return Ok(match resolved {
                Some(data_type) => Value::Null(data_type),
                None => Value::null_unknown(),
            });
        }

        match resolved {
            Some(DataType::Integer) => cast_to_integer(source),
            Some(DataType::Float) => cast_to_float(source),
            Some(DataType::Text) => cast_to_string(source),
            Some(DataType::Boolean) => cast_to_boolean(source),
            Some(DataType::Timestamp) => cast_to_timestamp(source),
            Some(DataType::Json) => cast_to_json(source),
            // Only `None` is produced above for other names; the wildcard keeps
            // the match exhaustive over the remaining `DataType` variants.
            _ => Err(Error::invalid_argument(format!(
                "Unsupported cast target type: {}",
                type_name
            ))),
        }
    }

    /// Returns a boxed copy of this (stateless) function object.
    fn clone_box(&self) -> Box<dyn ScalarFunction> {
        Box::new(CastFunction)
    }
}
fn cast_to_integer(value: &Value) -> Result<Value> {
match value {
Value::Integer(i) => Ok(Value::Integer(*i)),
Value::Float(f) => {
if !f.is_finite() {
return Err(Error::invalid_argument(format!(
"Cannot cast {} to INTEGER",
f
)));
}
if *f > i64::MAX as f64 || *f < i64::MIN as f64 {
return Err(Error::invalid_argument(format!(
"Float value {} out of INTEGER range",
f
)));
}
Ok(Value::Integer(*f as i64))
}
Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
Value::Text(s) => {
if s.is_empty() {
return Ok(Value::Integer(0));
}
if let Ok(i) = s.parse::<i64>() {
return Ok(Value::Integer(i));
}
if let Ok(f) = s.parse::<f64>() {
if !f.is_finite() || f > i64::MAX as f64 || f < i64::MIN as f64 {
return Err(Error::invalid_argument(format!(
"Cannot convert '{}' to INTEGER",
s
)));
}
return Ok(Value::Integer(f as i64));
}
Err(Error::invalid_argument(format!(
"Cannot convert '{}' to INTEGER",
s
)))
}
Value::Timestamp(t) => Ok(Value::Integer(t.timestamp())),
Value::Extension(_) => Ok(Value::Integer(0)),
Value::Null(dt) => Ok(Value::Null(*dt)),
}
}
fn cast_to_float(value: &Value) -> Result<Value> {
match value {
Value::Integer(i) => Ok(Value::Float(*i as f64)),
Value::Float(f) => Ok(Value::Float(*f)),
Value::Boolean(b) => Ok(Value::Float(if *b { 1.0 } else { 0.0 })),
Value::Text(s) => {
if s.is_empty() {
return Ok(Value::Float(0.0));
}
match s.parse::<f64>() {
Ok(f) => Ok(Value::Float(f)),
Err(_) => Err(Error::invalid_argument(format!(
"Cannot convert '{}' to FLOAT",
s
))),
}
}
Value::Timestamp(t) => Ok(Value::Float(t.timestamp() as f64)),
Value::Extension(_) => Err(Error::invalid_argument("Cannot convert JSON to FLOAT")),
Value::Null(dt) => Ok(Value::Null(*dt)),
}
}
/// Converts `value` to `Value::Text`.
///
/// Floats are rendered with six decimal places, timestamps as RFC 3339, and
/// extension values tagged as JSON as their payload after the tag byte (an
/// empty string if that payload is not valid UTF-8). Any other extension value
/// becomes the empty string. Nulls are passed through with their type tag.
/// This conversion never fails.
fn cast_to_string(value: &Value) -> Result<Value> {
    let rendered: String = match value {
        Value::Null(dt) => return Ok(Value::Null(*dt)),
        Value::Text(s) => return Ok(Value::Text(s.clone())),
        Value::Extension(data) => {
            let payload = if data.first() == Some(&(DataType::Json as u8)) {
                // First byte is the type tag; the rest is the JSON text.
                std::str::from_utf8(&data[1..]).unwrap_or("")
            } else {
                ""
            };
            return Ok(Value::Text(SmartString::from(payload)));
        }
        Value::Integer(i) => i.to_string(),
        Value::Float(f) => format!("{:.6}", f),
        Value::Boolean(b) => b.to_string(),
        Value::Timestamp(t) => t.to_rfc3339(),
    };
    Ok(Value::Text(SmartString::from_string(rendered)))
}
fn cast_to_boolean(value: &Value) -> Result<Value> {
match value {
Value::Boolean(b) => Ok(Value::Boolean(*b)),
Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
Value::Float(f) => Ok(Value::Boolean(*f != 0.0)),
Value::Text(s) => {
let lower = s.to_lowercase();
let is_true = lower == "true"
|| lower == "yes"
|| lower == "1"
|| (!lower.is_empty() && lower != "0" && lower != "false" && lower != "no");
Ok(Value::Boolean(is_true))
}
Value::Timestamp(_) => Err(Error::invalid_argument(
"Cannot convert TIMESTAMP to BOOLEAN",
)),
Value::Extension(_) => Err(Error::invalid_argument("Cannot convert JSON to BOOLEAN")),
Value::Null(dt) => Ok(Value::Null(*dt)),
}
}
fn cast_to_timestamp(value: &Value) -> Result<Value> {
match value {
Value::Timestamp(t) => Ok(Value::Timestamp(*t)),
Value::Text(s) => {
match crate::core::parse_timestamp(s) {
Ok(t) => Ok(Value::Timestamp(t)),
Err(_) => Err(Error::invalid_argument(format!(
"Cannot parse '{}' as TIMESTAMP",
s
))),
}
}
Value::Integer(i) => {
use chrono::{TimeZone, Utc};
match Utc.timestamp_opt(*i, 0) {
chrono::LocalResult::Single(t) => Ok(Value::Timestamp(t)),
_ => Err(Error::invalid_argument(format!(
"Invalid Unix timestamp: {}",
i
))),
}
}
_ => Err(Error::invalid_argument(format!(
"Cannot convert {:?} to TIMESTAMP",
value.data_type()
))),
}
}
/// Converts `value` to a JSON value built with `Value::json`.
///
/// Extension values already tagged as JSON are cloned as-is. Text is passed to
/// `Value::json` verbatim, numbers and booleans via their `to_string()` form,
/// and timestamps as a double-quoted RFC 3339 string. Nulls and non-JSON
/// extension values become the JSON literal `null`. This conversion never fails.
fn cast_to_json(value: &Value) -> Result<Value> {
    let json = match value {
        // Must precede the plain `Extension` arm below so JSON payloads survive.
        Value::Extension(data) if data.first() == Some(&(DataType::Json as u8)) => value.clone(),
        Value::Extension(_) | Value::Null(_) => Value::json("null"),
        Value::Text(s) => Value::json(s.as_ref()),
        Value::Boolean(b) => Value::json(b.to_string()),
        Value::Integer(i) => Value::json(i.to_string()),
        Value::Float(f) => Value::json(f.to_string()),
        Value::Timestamp(t) => Value::json(format!("\"{}\"", t.to_rfc3339())),
    };
    Ok(json)
}
/// Scalar function implementing `COLLATE(value, collation_name)`.
///
/// The type is a stateless unit struct; its behavior is provided by its
/// `ScalarFunction` implementation.
#[derive(Debug, Default)]
pub struct CollateFunction;
/// `COLLATE(value, collation_name)`: renders `value` as text and normalises it
/// according to the named collation.
impl ScalarFunction for CollateFunction {
    /// Name of the function.
    fn name(&self) -> &str {
        "COLLATE"
    }

    /// Builds the function metadata: a scalar function taking an `Any` value and
    /// a `String` collation name and returning `String`.
    fn info(&self) -> FunctionInfo {
        // NOTE(review): the two trailing `2`s are presumably the minimum and
        // maximum argument counts -- confirm against `FunctionSignature::new`.
        let signature = FunctionSignature::new(
            FunctionDataType::String,
            vec![FunctionDataType::Any, FunctionDataType::String],
            2,
            2,
        );
        FunctionInfo::new(
            "COLLATE",
            FunctionType::Scalar,
            "Applies a collation to a string value for sorting and comparison",
            signature,
        )
    }

    /// Applies the collation named by `args[1]` to the text form of `args[0]`.
    ///
    /// A null first argument yields an untyped null before the collation
    /// argument is inspected.
    ///
    /// # Errors
    ///
    /// Fails when the second argument is not text or names an unsupported
    /// collation.
    fn evaluate(&self, args: &[Value]) -> Result<Value> {
        validate_arg_count!(args, "COLLATE", 2);
        let (subject, collation_arg) = (&args[0], &args[1]);

        if subject.is_null() {
            return Ok(Value::null_unknown());
        }

        let text = match subject {
            Value::Null(_) => return Ok(Value::null_unknown()),
            Value::Text(s) => s.to_string(),
            Value::Boolean(b) => b.to_string(),
            Value::Integer(i) => i.to_string(),
            Value::Float(f) => f.to_string(),
            Value::Timestamp(t) => t.to_rfc3339(),
            Value::Extension(data) => {
                if data.first() == Some(&(DataType::Json as u8)) {
                    // First byte is the type tag; the rest is the JSON text.
                    std::str::from_utf8(&data[1..]).unwrap_or("").to_string()
                } else {
                    String::new()
                }
            }
        };

        let collation = if let Value::Text(name) = collation_arg {
            name.to_uppercase()
        } else {
            return Err(Error::invalid_argument(
                "COLLATE requires a string as the second argument",
            ));
        };

        apply_collation(&text, &collation)
            .map(|collated| Value::Text(SmartString::from_string(collated)))
    }

    /// Returns a boxed copy of this (stateless) function object.
    fn clone_box(&self) -> Box<dyn ScalarFunction> {
        Box::new(CollateFunction)
    }
}
fn apply_collation(s: &str, collation: &str) -> Result<String> {
match collation {
"BINARY" => Ok(s.to_string()),
"NOCASE" | "CASE_INSENSITIVE" => Ok(s.to_lowercase()),
"NOACCENT" | "ACCENT_INSENSITIVE" => Ok(remove_accents(s)),
"NUMERIC" => Ok(s.to_string()), _ => Err(Error::invalid_argument(format!(
"Unsupported collation: {}",
collation
))),
}
}
/// Returns `s` with common Latin-1 accented letters replaced by their base
/// letter and combining diacritical marks removed.
///
/// Precomposed letters such as `É` map to `E`; combining marks in
/// U+0300..=U+036F (as found in decomposed text) are dropped. All other
/// characters are kept unchanged.
fn remove_accents(s: &str) -> String {
    s.chars()
        // Combining marks match none of the letter patterns below, so they can
        // be filtered out before the mapping step.
        .filter(|&c| !is_combining_mark(c))
        .map(|c| match c {
            'À'..='Å' => 'A',
            'à'..='å' => 'a',
            'È'..='Ë' => 'E',
            'è'..='ë' => 'e',
            'Ì'..='Ï' => 'I',
            'ì'..='ï' => 'i',
            'Ò'..='Ö' => 'O',
            'ò'..='ö' => 'o',
            'Ù'..='Ü' => 'U',
            'ù'..='ü' => 'u',
            'Ç' => 'C',
            'ç' => 'c',
            'Ñ' => 'N',
            'ñ' => 'n',
            // Y with acute (U+00DD / U+00FD) is handled alongside Y with diaeresis.
            'Ý' | 'Ÿ' => 'Y',
            'ý' | 'ÿ' => 'y',
            other => other,
        })
        .collect()
}

/// Returns `true` when `c` lies in the Combining Diacritical Marks block
/// (U+0300..=U+036F).
fn is_combining_mark(c: char) -> bool {
    matches!(c, '\u{0300}'..='\u{036F}')
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `function` on every `(first argument, second argument, expected)`
    /// row and asserts that the call succeeds with the expected value.
    fn assert_evaluates(function: &dyn ScalarFunction, rows: &[(Value, Value, Value)]) {
        for (first, second, expected) in rows {
            let actual = function
                .evaluate(&[first.clone(), second.clone()])
                .unwrap();
            assert_eq!(&actual, expected);
        }
    }

    #[test]
    fn test_cast_to_integer() {
        assert_evaluates(
            &CastFunction,
            &[
                (Value::Integer(42), Value::text("INTEGER"), Value::Integer(42)),
                (Value::Float(3.7), Value::text("INT"), Value::Integer(3)),
                (Value::text("123"), Value::text("INTEGER"), Value::Integer(123)),
                (Value::Boolean(true), Value::text("INT"), Value::Integer(1)),
                (Value::Boolean(false), Value::text("INT"), Value::Integer(0)),
                (Value::text(""), Value::text("INTEGER"), Value::Integer(0)),
            ],
        );
    }

    #[test]
    fn test_cast_text_to_integer_edge_cases() {
        // Non-finite and out-of-range float spellings must be rejected.
        let rejected = [
            Value::text("inf"),
            Value::text("-inf"),
            Value::text("NaN"),
            Value::text("1e30"),
        ];
        for input in &rejected {
            assert!(CastFunction
                .evaluate(&[input.clone(), Value::text("INTEGER")])
                .is_err());
        }

        // A fractional spelling is accepted and truncated.
        assert_evaluates(
            &CastFunction,
            &[(Value::text("3.14"), Value::text("INTEGER"), Value::Integer(3))],
        );
    }

    #[test]
    fn test_cast_to_float() {
        assert_evaluates(
            &CastFunction,
            &[
                (Value::Integer(42), Value::text("FLOAT"), Value::Float(42.0)),
                (Value::Float(3.5), Value::text("REAL"), Value::Float(3.5)),
                (Value::text("2.5"), Value::text("DOUBLE"), Value::Float(2.5)),
            ],
        );
    }

    #[test]
    fn test_cast_to_string() {
        assert_evaluates(
            &CastFunction,
            &[
                (Value::Integer(42), Value::text("TEXT"), Value::text("42")),
                (Value::Boolean(true), Value::text("STRING"), Value::text("true")),
                (Value::text("hello"), Value::text("VARCHAR"), Value::text("hello")),
            ],
        );
    }

    #[test]
    fn test_cast_to_boolean() {
        assert_evaluates(
            &CastFunction,
            &[
                (Value::Integer(1), Value::text("BOOL"), Value::Boolean(true)),
                (Value::Integer(0), Value::text("BOOLEAN"), Value::Boolean(false)),
                (Value::text("true"), Value::text("BOOL"), Value::Boolean(true)),
                (Value::text("false"), Value::text("BOOL"), Value::Boolean(false)),
                (Value::text("yes"), Value::text("BOOL"), Value::Boolean(true)),
            ],
        );
    }

    #[test]
    fn test_cast_null_handling() {
        let cases = [
            (Value::text("INTEGER"), "CAST(NULL AS INTEGER) should return NULL"),
            (Value::text("TEXT"), "CAST(NULL AS TEXT) should return NULL"),
            (Value::text("FLOAT"), "CAST(NULL AS FLOAT) should return NULL"),
            (Value::text("BOOLEAN"), "CAST(NULL AS BOOLEAN) should return NULL"),
        ];
        for (target, message) in &cases {
            let result = CastFunction
                .evaluate(&[Value::null_unknown(), target.clone()])
                .unwrap();
            assert!(result.is_null(), "{}", message);
        }
    }

    #[test]
    fn test_collate_binary() {
        assert_evaluates(
            &CollateFunction,
            &[(Value::text("Hello"), Value::text("BINARY"), Value::text("Hello"))],
        );
    }

    #[test]
    fn test_collate_nocase() {
        assert_evaluates(
            &CollateFunction,
            &[
                (Value::text("HELLO"), Value::text("NOCASE"), Value::text("hello")),
                (
                    Value::text("Hello World"),
                    Value::text("CASE_INSENSITIVE"),
                    Value::text("hello world"),
                ),
            ],
        );
    }

    #[test]
    fn test_collate_noaccent() {
        assert_evaluates(
            &CollateFunction,
            &[
                (Value::text("Café"), Value::text("NOACCENT"), Value::text("Cafe")),
                (
                    Value::text("Naïve"),
                    Value::text("ACCENT_INSENSITIVE"),
                    Value::text("Naive"),
                ),
            ],
        );
    }

    #[test]
    fn test_collate_null_handling() {
        let outcome = CollateFunction
            .evaluate(&[Value::null_unknown(), Value::text("NOCASE")])
            .unwrap();
        assert!(outcome.is_null());
    }

    #[test]
    fn test_collate_unsupported() {
        let outcome = CollateFunction.evaluate(&[Value::text("test"), Value::text("INVALID")]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_remove_accents() {
        let cases = [
            ("Café", "Cafe"),
            ("Naïve", "Naive"),
            ("Résumé", "Resume"),
            ("Élève", "Eleve"),
            ("Über", "Uber"),
            ("Español", "Espanol"),
        ];
        for (input, expected) in &cases {
            assert_eq!(remove_accents(input), *expected);
        }
    }
}