use super::Rule;
use crate::velesql::ast::{CompareOp, Value};
use crate::velesql::error::ParseError;
/// Maps a textual comparison operator to its [`CompareOp`] variant.
///
/// `!=` and `<>` are accepted as synonyms for `CompareOp::NotEq`.
///
/// # Errors
///
/// Returns a syntax error (position 0, offending text `op`) for any other string.
pub(crate) fn compare_op_from_str(op: &str) -> Result<CompareOp, ParseError> {
    let parsed = match op {
        "=" => CompareOp::Eq,
        "!=" | "<>" => CompareOp::NotEq,
        ">" => CompareOp::Gt,
        ">=" => CompareOp::Gte,
        "<" => CompareOp::Lt,
        "<=" => CompareOp::Lte,
        unknown => {
            return Err(ParseError::syntax(0, unknown, "Invalid comparison operator"));
        }
    };
    Ok(parsed)
}
/// Parses a raw literal token into a [`Value`].
///
/// Recognition order:
/// 1. a single-quoted string (at least two bytes, starting and ending with `'`),
///    unescaped via `unescape_string_literal`;
/// 2. `true` / `false` / `null`, compared ASCII case-insensitively;
/// 3. otherwise a numeric literal via `parse_numeric_value`.
///
/// # Errors
///
/// Propagates the syntax error from `parse_numeric_value` when nothing matches.
pub(crate) fn parse_value_from_str(input: &str) -> Result<Value, ParseError> {
    let is_quoted = input.len() >= 2 && input.starts_with('\'') && input.ends_with('\'');

    if is_quoted {
        Ok(Value::String(unescape_string_literal(input)))
    } else if input.eq_ignore_ascii_case("true") {
        Ok(Value::Boolean(true))
    } else if input.eq_ignore_ascii_case("false") {
        Ok(Value::Boolean(false))
    } else if input.eq_ignore_ascii_case("null") {
        Ok(Value::Null)
    } else {
        parse_numeric_value(input)
    }
}
/// Attempts to read `s` as an integer literal.
///
/// Values that fit in `i64` become `Value::Integer`; otherwise values that fit in
/// `u64` (i.e. `i64::MAX + 1 ..= u64::MAX`) become `Value::UnsignedInteger`.
/// Returns `None` when neither parse succeeds.
fn try_parse_integer(s: &str) -> Option<Value> {
    if let Ok(signed) = s.parse::<i64>() {
        return Some(Value::Integer(signed));
    }
    s.parse::<u64>().ok().map(Value::UnsignedInteger)
}
/// Parses `input` as a numeric literal.
///
/// Integers are tried first via `try_parse_integer` (`i64`, then `u64`). Anything else
/// that parses as a *finite* `f64` becomes `Value::Float`, so integer literals too large
/// for `u64` still come back as floats.
///
/// # Errors
///
/// Returns a syntax error with the message `Invalid value: {input}` when `input` is not
/// numeric. This includes the spellings `inf`, `infinity` and `nan` (any case, optionally
/// signed), which `f64::from_str` would otherwise accept. It also includes literals whose
/// magnitude overflows `f64` to infinity (e.g. `1e999`).
fn parse_numeric_value(input: &str) -> Result<Value, ParseError> {
    if let Some(int_val) = try_parse_integer(input) {
        return Ok(int_val);
    }
    match input.parse::<f64>() {
        // Reject non-finite results: otherwise a bare word like "nan" or "inf" would be
        // silently accepted as a number, and a NaN value never compares equal to itself.
        Ok(f) if f.is_finite() => Ok(Value::Float(f)),
        _ => Err(ParseError::syntax(
            0,
            input,
            format!("Invalid value: {input}"),
        )),
    }
}
/// Converts a scalar pest pair into a [`Value`], dispatching on its grammar rule.
///
/// * `integer` / `float` are parsed numerically;
/// * `string` is unescaped via `unescape_string_literal`;
/// * `boolean` is `true` iff the text equals `true` ASCII case-insensitively;
/// * `null_value` yields `Value::Null`;
/// * `parameter` yields `Value::Parameter` with leading `$` characters removed.
///
/// # Errors
///
/// Fails on numeric parse errors, or with "Unknown value type" for any other rule.
pub(crate) fn parse_scalar_from_rule(
    pair: &pest::iterators::Pair<'_, Rule>,
) -> Result<Value, ParseError> {
    let text = pair.as_str();
    match pair.as_rule() {
        Rule::integer => parse_integer_literal(text),
        Rule::float => parse_float_literal(text),
        Rule::string => Ok(Value::String(unescape_string_literal(text))),
        Rule::boolean => Ok(Value::Boolean(text.eq_ignore_ascii_case("true"))),
        Rule::null_value => Ok(Value::Null),
        Rule::parameter => Ok(parse_parameter_value(text)),
        _ => Err(ParseError::syntax(0, text, "Unknown value type")),
    }
}
/// Builds a `Value::Parameter` from raw parameter text such as `$name`.
///
/// Every leading `$` is removed (not just the first one).
fn parse_parameter_value(raw: &str) -> Value {
    let name = raw.trim_start_matches('$');
    Value::Parameter(name.to_owned())
}
/// Parses an integer literal token (`i64` first, then `u64`).
///
/// # Errors
///
/// Returns an "Invalid integer" syntax error when `s` fits neither type.
fn parse_integer_literal(s: &str) -> Result<Value, ParseError> {
    match try_parse_integer(s) {
        Some(value) => Ok(value),
        None => Err(ParseError::syntax(0, s, "Invalid integer")),
    }
}
/// Parses a float literal token into `Value::Float`.
///
/// # Errors
///
/// Returns an "Invalid float" syntax error when `s` is not accepted by `f64::from_str`.
fn parse_float_literal(s: &str) -> Result<Value, ParseError> {
    match s.parse::<f64>() {
        Ok(number) => Ok(Value::Float(number)),
        Err(_) => Err(ParseError::syntax(0, s, "Invalid float")),
    }
}
/// Reads the first inner pair of a clause pair as a `u64`.
///
/// `clause_name` is used only in error messages (e.g. a LIMIT-style clause would pass
/// `"LIMIT"`).
///
/// # Errors
///
/// * "Expected integer for {clause_name}" when the clause has no inner pair;
/// * "Invalid {clause_name} value" when the inner text does not parse as `u64`.
pub(crate) fn parse_u64_clause(
    pair: pest::iterators::Pair<'_, Rule>,
    clause_name: &str,
) -> Result<u64, ParseError> {
    let int_pair = match pair.into_inner().next() {
        Some(inner) => inner,
        None => {
            return Err(ParseError::syntax(
                0,
                "",
                format!("Expected integer for {clause_name}"),
            ));
        }
    };
    let digits = int_pair.as_str();
    match digits.parse::<u64>() {
        Ok(n) => Ok(n),
        Err(_) => Err(ParseError::syntax(
            0,
            digits,
            format!("Invalid {clause_name} value"),
        )),
    }
}
/// Strips the surrounding delimiters from a quoted string literal and collapses each
/// doubled single quote (`''`) into one (`'`), e.g. `'O''Brien'` → `O'Brien`.
///
/// The first and last *characters* are removed unconditionally, so callers are expected
/// to pass a literal already known to be quoted. Removal is char-aware: a non-ASCII
/// first or last character never splits a UTF-8 sequence. Inputs shorter than two
/// characters yield an empty string instead of panicking.
pub(crate) fn unescape_string_literal(raw: &str) -> String {
    let mut inner = raw.chars();
    // Drop the opening and closing delimiter; each call is a no-op once the
    // iterator is exhausted, so short inputs cannot underflow or panic.
    inner.next();
    inner.next_back();
    inner.as_str().replace("''", "'")
}
/// Applies `extractor` to every direct child of `list_pair` whose rule is `item_rule`,
/// in document order, and gathers the results.
///
/// Children with any other rule (separators, punctuation, …) are skipped.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `extractor`.
pub(crate) fn extract_key_value_list<T>(
    list_pair: pest::iterators::Pair<'_, super::Rule>,
    item_rule: super::Rule,
    extractor: impl Fn(pest::iterators::Pair<'_, super::Rule>) -> Result<T, ParseError>,
) -> Result<Vec<T>, ParseError> {
    let mut items = Vec::new();
    for child in list_pair.into_inner() {
        if child.as_rule() == item_rule {
            items.push(extractor(child)?);
        }
    }
    Ok(items)
}
/// Removes identifier quoting from `s` after trimming surrounding whitespace.
///
/// * Backtick-quoted: the outer backticks are removed and each doubled backtick
///   inside is collapsed to one.
/// * Double-quoted: the outer `"` are removed and each doubled `""` inside is
///   collapsed to one `"`.
/// * Anything else, including a lone backtick or `"`, is returned trimmed but
///   otherwise unchanged.
pub(crate) fn strip_identifier_quotes(s: &str) -> String {
    let s = s.trim();
    if let Some(inner) = s.strip_prefix('`').and_then(|rest| rest.strip_suffix('`')) {
        // Doubling the quote character escapes it, mirroring the `""` rule below so the
        // two quoting styles behave consistently.
        inner.replace("``", "`")
    } else if let Some(inner) = s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        inner.replace("\"\"", "\"")
    } else {
        s.to_string()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the string-level literal/identifier helpers in this module.
    use super::*;

    #[test]
    fn test_compare_op_from_str_all_operators() {
        assert_eq!(compare_op_from_str("=").unwrap(), CompareOp::Eq);
        assert_eq!(compare_op_from_str("!=").unwrap(), CompareOp::NotEq);
        assert_eq!(compare_op_from_str("<>").unwrap(), CompareOp::NotEq);
        assert_eq!(compare_op_from_str(">").unwrap(), CompareOp::Gt);
        assert_eq!(compare_op_from_str(">=").unwrap(), CompareOp::Gte);
        assert_eq!(compare_op_from_str("<").unwrap(), CompareOp::Lt);
        assert_eq!(compare_op_from_str("<=").unwrap(), CompareOp::Lte);
    }

    #[test]
    fn test_compare_op_from_str_invalid() {
        assert!(compare_op_from_str("??").is_err());
    }

    #[test]
    fn test_compare_op_from_str_empty() {
        assert!(compare_op_from_str("").is_err());
    }

    #[test]
    fn test_parse_value_from_str_integer() {
        assert_eq!(parse_value_from_str("42").unwrap(), Value::Integer(42));
    }

    #[test]
    fn test_parse_value_from_str_negative_integer() {
        assert_eq!(parse_value_from_str("-7").unwrap(), Value::Integer(-7));
    }

    #[test]
    fn test_parse_value_from_str_float() {
        assert_eq!(parse_value_from_str("2.72").unwrap(), Value::Float(2.72));
    }

    #[test]
    fn test_parse_value_from_str_string() {
        assert_eq!(
            parse_value_from_str("'hello'").unwrap(),
            Value::String("hello".to_string())
        );
    }

    #[test]
    fn test_parse_value_from_str_string_with_escaped_quote() {
        assert_eq!(
            parse_value_from_str("'O''Brien'").unwrap(),
            Value::String("O'Brien".to_string())
        );
    }

    #[test]
    fn test_parse_value_from_str_boolean() {
        assert_eq!(parse_value_from_str("true").unwrap(), Value::Boolean(true));
        assert_eq!(
            parse_value_from_str("FALSE").unwrap(),
            Value::Boolean(false)
        );
    }

    #[test]
    fn test_parse_value_from_str_null() {
        assert_eq!(parse_value_from_str("null").unwrap(), Value::Null);
        assert_eq!(parse_value_from_str("NULL").unwrap(), Value::Null);
    }

    #[test]
    fn test_parse_value_from_str_invalid() {
        assert!(parse_value_from_str("not_a_value").is_err());
    }

    #[test]
    fn test_parse_value_from_str_empty_and_lone_quote() {
        // Neither is a quoted string (needs >= 2 bytes) nor numeric.
        assert!(parse_value_from_str("").is_err());
        assert!(parse_value_from_str("'").is_err());
    }

    // NOTE(review): this test never calls `parse_u64_clause`; it only checks that
    // `format!` interpolates. A real test needs a pest `Pair` produced by the grammar's
    // parser (not visible in this file) — replace once that is wired up.
    #[test]
    fn test_parse_u64_clause_error_message() {
        let msg = format!("Expected integer for {}", "LIMIT");
        assert!(msg.contains("LIMIT"));
    }

    #[test]
    fn test_strip_identifier_quotes_backtick() {
        assert_eq!(strip_identifier_quotes("`name`"), "name");
    }

    #[test]
    fn test_strip_identifier_quotes_double() {
        assert_eq!(strip_identifier_quotes("\"col\""), "col");
    }

    #[test]
    fn test_strip_identifier_quotes_escaped_double() {
        assert_eq!(strip_identifier_quotes("\"col\"\"name\""), "col\"name");
    }

    #[test]
    fn test_strip_identifier_quotes_plain() {
        assert_eq!(strip_identifier_quotes("plain"), "plain");
    }

    #[test]
    fn test_strip_identifier_quotes_trimmed() {
        assert_eq!(strip_identifier_quotes(" `spaced` "), "spaced");
    }

    #[test]
    fn test_strip_identifier_quotes_lone_quote_unchanged() {
        assert_eq!(strip_identifier_quotes("`"), "`");
        assert_eq!(strip_identifier_quotes("\""), "\"");
    }

    #[test]
    fn test_unescape_string_literal_simple() {
        assert_eq!(unescape_string_literal("'hello'"), "hello");
    }

    #[test]
    fn test_unescape_string_literal_escaped_quote() {
        assert_eq!(unescape_string_literal("'O''Brien'"), "O'Brien");
    }

    #[test]
    fn test_unescape_string_literal_multiple_escapes() {
        assert_eq!(
            unescape_string_literal("'It''s a ''test'''"),
            "It's a 'test'"
        );
    }

    #[test]
    fn test_unescape_string_literal_empty() {
        assert_eq!(unescape_string_literal("''"), "");
    }

    #[test]
    fn test_parse_integer_i64_max() {
        let result = parse_value_from_str("9223372036854775807").unwrap();
        assert_eq!(result, Value::Integer(i64::MAX));
    }

    #[test]
    fn test_parse_integer_u64_value() {
        let result = parse_value_from_str("9223372036854775808").unwrap();
        assert_eq!(result, Value::UnsignedInteger(9_223_372_036_854_775_808));
    }

    #[test]
    fn test_parse_integer_u64_max() {
        let result = parse_value_from_str("18446744073709551615").unwrap();
        assert_eq!(result, Value::UnsignedInteger(u64::MAX));
    }

    #[test]
    fn test_parse_integer_overflow_to_float() {
        let result = parse_value_from_str("18446744073709551616").unwrap();
        assert!(matches!(result, Value::Float(_)));
    }
}