use crate::{Error, Result};
/// Aggregate function call prefixes (function name plus opening parenthesis) whose presence
/// marks a query as using [`UnsupportedFeature::Aggregates`].
const AGGREGATE_PREFIXES: &[&str] = &["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("];
/// Join spellings whose presence marks a query as using [`UnsupportedFeature::Joins`].
///
/// Every entry is space-padded on both sides.
const JOIN_KEYWORDS: &[&str] = &[
    " JOIN ",
    " INNER JOIN ",
    " LEFT JOIN ",
    " RIGHT JOIN ",
    " FULL JOIN ",
    " CROSS JOIN ",
];
/// Comparison operators that, when they appear from `WHERE` onwards, mark a query as using
/// [`UnsupportedFeature::RangeQueries`].
const RANGE_OPERATORS: &[&str] = &[">=", "<=", "!=", "<>", ">", "<"];
/// Stateless checker for whether a CQL `SELECT` uses only the subset of features that M2
/// supports; see [`M2SelectValidator::validate_select`].
///
/// The type has no fields, so it also derives `Default`, equality and `Hash` for free. This
/// lets it be embedded in structs that derive those traits.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct M2SelectValidator;
#[derive(Debug, Clone, PartialEq)]
pub struct SelectValidationResult {
pub has_partition_key_filter: bool,
pub has_clustering_filters: bool,
pub has_limit: bool,
pub unsupported_features: Vec<UnsupportedFeature>,
}
/// A CQL construct that the M2 `SELECT` path rejects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnsupportedFeature {
    /// An `ORDER BY` clause.
    OrderBy,
    /// The `ALLOW FILTERING` modifier.
    AllowFiltering,
    /// A call to an aggregate function (`COUNT`, `SUM`, `AVG`, `MIN`, `MAX`).
    Aggregates,
    /// A `GROUP BY` clause.
    GroupBy,
    /// A `HAVING` clause.
    Having,
    /// Any `JOIN` form.
    Joins,
    /// A non-equality comparison (`>`, `<`, `>=`, `<=`, `!=`, `<>`) in the `WHERE` clause.
    RangeQueries,
}
impl UnsupportedFeature {
fn label(self) -> &'static str {
match self {
UnsupportedFeature::OrderBy => "ORDER BY",
UnsupportedFeature::AllowFiltering => "ALLOW FILTERING",
UnsupportedFeature::Aggregates => "Aggregates (COUNT, SUM, AVG, MIN, MAX)",
UnsupportedFeature::GroupBy => "GROUP BY",
UnsupportedFeature::Having => "HAVING",
UnsupportedFeature::Joins => "JOINs",
UnsupportedFeature::RangeQueries => "Range queries (>, <, >=, <=, !=, <>)",
}
}
}
impl std::fmt::Display for UnsupportedFeature {
    /// Writes the feature's label.
    ///
    /// Uses `Formatter::pad` rather than `write_str`, so width, alignment and precision
    /// flags (e.g. `{:<20}`) are honoured like they are for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.label())
    }
}
impl M2SelectValidator {
    /// Checks whether `cql` is a `SELECT` form that M2 can execute.
    ///
    /// Detection is lexical, not a full CQL parse, and works on a normalized copy of the
    /// query:
    /// - the text is uppercased;
    /// - the contents of `'…'` string literals and `"…"` quoted identifiers are dropped;
    /// - whitespace runs are collapsed to a single space.
    ///
    /// Keywords then match only as whole words. Identifiers such as `brand`, `unlimited` or
    /// `somewhere` therefore do not count as `AND`, `LIMIT` or `WHERE`. Literal values such
    /// as `'a>b'` or `'order by'` are not mistaken for query syntax.
    ///
    /// # Errors
    ///
    /// Returns the error built by `unsupported_query_error`, listing every detected
    /// [`UnsupportedFeature`], when at least one is present.
    ///
    /// # Returns
    ///
    /// On success, flags describing the query's shape. `unsupported_features` is always
    /// empty, because any hit is reported as an error instead.
    pub fn validate_select(&self, cql: &str) -> Result<SelectValidationResult> {
        let query = normalize_for_scan(cql);
        let where_pos = keyword_positions(&query, "WHERE").next();
        // Text from the WHERE keyword to the end; empty when there is no WHERE clause.
        let where_clause = where_pos.map_or("", |pos| &query[pos..]);

        let rules: [(bool, UnsupportedFeature); 7] = [
            (has_keyword(&query, "ORDER BY"), UnsupportedFeature::OrderBy),
            (
                has_keyword(&query, "ALLOW FILTERING"),
                UnsupportedFeature::AllowFiltering,
            ),
            (
                // The shared constants carry a trailing '('; match the bare name as a call.
                AGGREGATE_PREFIXES
                    .iter()
                    .any(|p| has_function_call(&query, p.trim_end_matches('('))),
                UnsupportedFeature::Aggregates,
            ),
            (has_keyword(&query, "GROUP BY"), UnsupportedFeature::GroupBy),
            (has_keyword(&query, "HAVING"), UnsupportedFeature::Having),
            (
                // Every spelling ends in the word JOIN; the padding is replaced by word
                // boundaries.
                JOIN_KEYWORDS.iter().any(|j| has_keyword(&query, j.trim())),
                UnsupportedFeature::Joins,
            ),
            (
                has_range_operator_after(&query, where_pos),
                UnsupportedFeature::RangeQueries,
            ),
        ];
        let unsupported_features: Vec<UnsupportedFeature> = rules
            .iter()
            .filter_map(|&(hit, feature)| hit.then_some(feature))
            .collect();
        if !unsupported_features.is_empty() {
            return Err(unsupported_query_error(&unsupported_features));
        }
        Ok(SelectValidationResult {
            has_partition_key_filter: where_pos.is_some(),
            // Only an AND inside the WHERE clause indicates additional restrictions.
            has_clustering_filters: has_keyword(where_clause, "AND"),
            has_limit: has_keyword(&query, "LIMIT"),
            unsupported_features,
        })
    }
}

/// Returns true for bytes that may appear inside an unquoted CQL identifier or keyword.
fn is_identifier_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_'
}

/// Produces the text that the keyword and operator scans run over.
///
/// The output is uppercased, has quoted spans emptied, and has whitespace runs collapsed to a
/// single space.
///
/// Both `'…'` (string literal) and `"…"` (quoted identifier) spans are emptied. A doubled
/// quote inside a span (`'it''s'`) simply closes and reopens it, so it is masked as well. The
/// quote characters themselves are kept, so word boundaries around a literal survive. An
/// unterminated quote masks the rest of the input.
fn normalize_for_scan(cql: &str) -> String {
    let mut out = String::with_capacity(cql.len());
    let mut open_quote: Option<char> = None;
    for ch in cql.chars() {
        match open_quote {
            Some(quote) => {
                if ch == quote {
                    open_quote = None;
                    out.push(ch);
                }
            }
            None if ch == '\'' || ch == '"' => {
                open_quote = Some(ch);
                out.push(ch);
            }
            None if ch.is_whitespace() => {
                if !out.ends_with(' ') {
                    out.push(' ');
                }
            }
            None => out.extend(ch.to_uppercase()),
        }
    }
    out
}

/// Yields the byte offset of every occurrence of `word` in `haystack` that is not part of a
/// longer identifier.
///
/// An occurrence counts only when the bytes immediately before and after it, if any, are not
/// identifier bytes.
fn keyword_positions<'a>(haystack: &'a str, word: &'a str) -> impl Iterator<Item = usize> + 'a {
    let bytes = haystack.as_bytes();
    haystack
        .match_indices(word)
        .map(|(start, _)| start)
        .filter(move |&start| {
            let end = start + word.len();
            let clear_before = start == 0 || !is_identifier_byte(bytes[start - 1]);
            let clear_after = end == bytes.len() || !is_identifier_byte(bytes[end]);
            clear_before && clear_after
        })
}

/// True when `word` appears in `haystack` as a whole word (see [`keyword_positions`]).
fn has_keyword(haystack: &str, word: &str) -> bool {
    keyword_positions(haystack, word).next().is_some()
}

/// True when `name` appears in `haystack` as a whole word that is called like a function.
///
/// The word must be followed by `(`, optionally after the single space that
/// `normalize_for_scan` may leave between them.
fn has_function_call(haystack: &str, name: &str) -> bool {
    keyword_positions(haystack, name)
        .any(|start| haystack[start + name.len()..].trim_start().starts_with('('))
}
/// Reports whether any entry of `RANGE_OPERATORS` occurs in `cql_upper` at or after byte
/// offset `where_pos`.
///
/// Returns `false` when there is no `WHERE` position. `where_pos` must be a char boundary of
/// `cql_upper` (e.g. a value returned by `str::find` on the same string); otherwise the slice
/// panics.
fn has_range_operator_after(cql_upper: &str, where_pos: Option<usize>) -> bool {
    if let Some(start) = where_pos {
        let tail = &cql_upper[start..];
        RANGE_OPERATORS.iter().any(|operator| tail.contains(operator))
    } else {
        false
    }
}
fn unsupported_query_error(features: &[UnsupportedFeature]) -> Error {
let feature_list = features
.iter()
.map(|f| f.label())
.collect::<Vec<_>>()
.join(", ");
Error::unsupported_query(format!(
"Unsupported query form in M2. Unsupported features: [{}]. \
M2 supports: SELECT with partition/primary key equality and optional LIMIT. \
Try narrowing your WHERE clause to use only equality (=) on partition/primary keys.",
feature_list
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `cql`, panicking if the validator rejects it.
    fn accepted(cql: &str) -> SelectValidationResult {
        M2SelectValidator.validate_select(cql).unwrap()
    }

    /// Validates `cql`, asserts that it was rejected, and returns the error text.
    fn rejection_message(cql: &str) -> String {
        let outcome = M2SelectValidator.validate_select(cql);
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    /// Asserts that `cql` is rejected with a message mentioning every fragment.
    fn assert_rejected_mentioning(cql: &str, fragments: &[&str]) {
        let message = rejection_message(cql);
        for &fragment in fragments {
            assert!(
                message.contains(fragment),
                "expected {message:?} to mention {fragment:?}"
            );
        }
    }

    /// The value an accepted query with the given flags should produce.
    fn shape(partition: bool, clustering: bool, limit: bool) -> SelectValidationResult {
        SelectValidationResult {
            has_partition_key_filter: partition,
            has_clustering_filters: clustering,
            has_limit: limit,
            unsupported_features: Vec::new(),
        }
    }

    #[test]
    fn test_simple_select_with_partition_key() {
        assert_eq!(
            accepted("SELECT * FROM users WHERE user_id = 123"),
            shape(true, false, false)
        );
    }

    #[test]
    fn test_select_with_limit() {
        assert_eq!(
            accepted("SELECT * FROM users WHERE user_id = 123 LIMIT 10"),
            shape(true, false, true)
        );
    }

    #[test]
    fn test_select_with_clustering_columns() {
        assert_eq!(
            accepted("SELECT * FROM events WHERE user_id = 123 AND timestamp = '2024-01-01'"),
            shape(true, true, false)
        );
    }

    #[test]
    fn test_select_with_order_by() {
        assert_rejected_mentioning(
            "SELECT * FROM users WHERE user_id = 123 ORDER BY name ASC",
            &["ORDER BY", "Unsupported query form in M2"],
        );
    }

    #[test]
    fn test_select_with_allow_filtering() {
        assert_rejected_mentioning(
            "SELECT * FROM users WHERE email = 'test@example.com' ALLOW FILTERING",
            &["ALLOW FILTERING"],
        );
    }

    #[test]
    fn test_select_with_count_aggregate() {
        assert_rejected_mentioning(
            "SELECT COUNT(*) FROM users WHERE user_id = 123",
            &["Aggregates"],
        );
    }

    #[test]
    fn test_select_with_sum_aggregate() {
        assert_rejected_mentioning(
            "SELECT SUM(amount) FROM transactions WHERE user_id = 123",
            &["Aggregates"],
        );
    }

    #[test]
    fn test_select_with_group_by() {
        assert_rejected_mentioning(
            "SELECT user_id, COUNT(*) FROM users GROUP BY user_id",
            &["GROUP BY"],
        );
    }

    #[test]
    fn test_select_with_having() {
        assert_rejected_mentioning(
            "SELECT user_id, COUNT(*) FROM users GROUP BY user_id HAVING COUNT(*) > 5",
            &["HAVING"],
        );
    }

    #[test]
    fn test_select_with_join() {
        assert_rejected_mentioning(
            "SELECT u.* FROM users u JOIN orders o ON u.user_id = o.user_id",
            &["JOIN"],
        );
    }

    #[test]
    fn test_select_with_greater_than() {
        assert_rejected_mentioning("SELECT * FROM users WHERE age > 18", &["Range queries"]);
    }

    #[test]
    fn test_select_with_less_than_or_equal() {
        assert_rejected_mentioning("SELECT * FROM users WHERE age <= 65", &["Range queries"]);
    }

    #[test]
    fn test_select_with_not_equal() {
        assert_rejected_mentioning(
            "SELECT * FROM users WHERE status != 'deleted'",
            &["Range queries"],
        );
    }

    #[test]
    fn test_select_with_not_equal_alternative() {
        assert_rejected_mentioning(
            "SELECT * FROM users WHERE status <> 'deleted'",
            &["Range queries"],
        );
    }

    #[test]
    fn test_select_with_multiple_unsupported_features() {
        assert_rejected_mentioning(
            "SELECT COUNT(*) FROM users WHERE age > 18 GROUP BY country ORDER BY COUNT(*) DESC",
            &["ORDER BY", "Aggregates", "GROUP BY", "Range queries"],
        );
    }

    #[test]
    fn test_case_insensitive_detection() {
        for cql in [
            "select * from users where user_id = 123 order by name",
            "SeLeCt * FrOm users WhErE user_id = 123 OrDeR bY name",
        ] {
            assert!(M2SelectValidator.validate_select(cql).is_err());
        }
    }

    #[test]
    fn test_unsupported_feature_display() {
        let cases = [
            (UnsupportedFeature::OrderBy, "ORDER BY"),
            (UnsupportedFeature::AllowFiltering, "ALLOW FILTERING"),
            (
                UnsupportedFeature::Aggregates,
                "Aggregates (COUNT, SUM, AVG, MIN, MAX)",
            ),
            (UnsupportedFeature::GroupBy, "GROUP BY"),
            (UnsupportedFeature::Having, "HAVING"),
            (UnsupportedFeature::Joins, "JOINs"),
            (
                UnsupportedFeature::RangeQueries,
                "Range queries (>, <, >=, <=, !=, <>)",
            ),
        ];
        for (feature, expected) in cases {
            assert_eq!(feature.to_string(), expected);
        }
    }

    #[test]
    fn test_validation_result_equality() {
        let built = SelectValidationResult {
            has_partition_key_filter: true,
            has_clustering_filters: false,
            has_limit: true,
            unsupported_features: vec![],
        };
        assert_eq!(built, shape(true, false, true));
    }

    #[test]
    fn test_all_aggregate_functions() {
        for aggregate in ["COUNT", "SUM", "AVG", "MIN", "MAX"] {
            let cql = format!("SELECT {aggregate}(*) FROM users WHERE user_id = 123");
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect {aggregate} aggregate"
            );
        }
    }

    #[test]
    fn test_all_join_types() {
        for join_type in [
            "JOIN",
            "INNER JOIN",
            "LEFT JOIN",
            "RIGHT JOIN",
            "FULL JOIN",
            "CROSS JOIN",
        ] {
            let cql = format!("SELECT * FROM users {join_type} orders ON users.id = orders.user_id");
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect {join_type} join"
            );
        }
    }

    #[test]
    fn test_all_range_operators() {
        for operator in [">", "<", ">=", "<=", "!=", "<>"] {
            let cql = format!("SELECT * FROM users WHERE age {operator} 18");
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect range operator: {operator}"
            );
        }
    }

    #[test]
    fn test_select_without_where() {
        assert_eq!(accepted("SELECT * FROM users"), shape(false, false, false));
    }

    #[test]
    fn test_complex_valid_query() {
        let cql = "SELECT user_id, name, email FROM users \
                   WHERE user_id = 123 AND status = 'active' LIMIT 100";
        assert_eq!(accepted(cql), shape(true, true, true));
    }
}