1use crate::{Error, Result};
41
/// Upper-cased call prefixes (name plus opening parenthesis) of the CQL aggregate
/// functions that M2 rejects.
const AGGREGATE_PREFIXES: &[&str] = &["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("];

/// Upper-cased, space-delimited JOIN spellings that M2 rejects.
///
/// Every qualified entry also contains `" JOIN "`; the qualified forms are listed to
/// document which JOIN variants are meant to be covered.
const JOIN_KEYWORDS: &[&str] = &[
    " JOIN ",
    " INNER JOIN ",
    " LEFT JOIN ",
    " RIGHT JOIN ",
    " FULL JOIN ",
    " CROSS JOIN ",
];

/// Comparison operators that turn a WHERE restriction into a range/inequality query.
/// Two-character operators are listed before their one-character prefixes.
const RANGE_OPERATORS: &[&str] = &[">=", "<=", "!=", "<>", ">", "<"];
58
/// Validator that decides whether a CQL `SELECT` statement is a form supported in M2.
///
/// It carries no state: use the unit value `M2SelectValidator` (or `Default::default()`)
/// and call `validate_select`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct M2SelectValidator;
65
66#[derive(Debug, Clone, PartialEq)]
68pub struct SelectValidationResult {
69 pub has_partition_key_filter: bool,
71 pub has_clustering_filters: bool,
73 pub has_limit: bool,
75 pub unsupported_features: Vec<UnsupportedFeature>,
77}
78
/// A SELECT feature that M2 detects and refuses to execute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnsupportedFeature {
    /// An `ORDER BY` clause.
    OrderBy,
    /// The `ALLOW FILTERING` modifier.
    AllowFiltering,
    /// An aggregate function call (`COUNT`, `SUM`, `AVG`, `MIN`, `MAX`).
    Aggregates,
    /// A `GROUP BY` clause.
    GroupBy,
    /// A `HAVING` clause.
    Having,
    /// Any kind of `JOIN`.
    Joins,
    /// A range or inequality operator (`>`, `<`, `>=`, `<=`, `!=`, `<>`) in the WHERE clause.
    RangeQueries,
}

impl UnsupportedFeature {
    /// Human-readable name used in validation error messages and by `Display`.
    fn label(self) -> &'static str {
        match self {
            Self::OrderBy => "ORDER BY",
            Self::AllowFiltering => "ALLOW FILTERING",
            Self::Aggregates => "Aggregates (COUNT, SUM, AVG, MIN, MAX)",
            Self::GroupBy => "GROUP BY",
            Self::Having => "HAVING",
            Self::Joins => "JOINs",
            Self::RangeQueries => "Range queries (>, <, >=, <=, !=, <>)",
        }
    }
}

impl std::fmt::Display for UnsupportedFeature {
    /// Writes the feature's label verbatim; width and fill flags are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = Self::label(*self);
        formatter.write_str(text)
    }
}
117
118impl M2SelectValidator {
119 pub fn validate_select(&self, cql: &str) -> Result<SelectValidationResult> {
135 let cql_upper = cql.to_uppercase();
136 let where_pos = cql_upper.find("WHERE");
137
138 let rules: &[(bool, UnsupportedFeature)] = &[
141 (cql_upper.contains("ORDER BY"), UnsupportedFeature::OrderBy),
142 (
143 cql_upper.contains("ALLOW FILTERING"),
144 UnsupportedFeature::AllowFiltering,
145 ),
146 (
147 AGGREGATE_PREFIXES.iter().any(|p| cql_upper.contains(p)),
148 UnsupportedFeature::Aggregates,
149 ),
150 (cql_upper.contains("GROUP BY"), UnsupportedFeature::GroupBy),
151 (cql_upper.contains("HAVING"), UnsupportedFeature::Having),
152 (
153 JOIN_KEYWORDS.iter().any(|j| cql_upper.contains(j)),
154 UnsupportedFeature::Joins,
155 ),
156 (
157 has_range_operator_after(&cql_upper, where_pos),
158 UnsupportedFeature::RangeQueries,
159 ),
160 ];
161
162 let unsupported_features: Vec<UnsupportedFeature> = rules
163 .iter()
164 .filter_map(|(hit, feat)| hit.then_some(*feat))
165 .collect();
166
167 if !unsupported_features.is_empty() {
168 return Err(unsupported_query_error(&unsupported_features));
169 }
170
171 let has_where = where_pos.is_some();
172 Ok(SelectValidationResult {
173 has_partition_key_filter: has_where,
174 has_clustering_filters: has_where && cql_upper.contains("AND"),
175 has_limit: cql_upper.contains("LIMIT"),
176 unsupported_features,
177 })
178 }
179}
180
181fn has_range_operator_after(cql_upper: &str, where_pos: Option<usize>) -> bool {
187 match where_pos {
188 Some(pos) => {
189 let after_where = &cql_upper[pos..];
190 RANGE_OPERATORS.iter().any(|op| after_where.contains(op))
191 }
192 None => false,
193 }
194}
195
196fn unsupported_query_error(features: &[UnsupportedFeature]) -> Error {
197 let feature_list = features
198 .iter()
199 .map(|f| f.label())
200 .collect::<Vec<_>>()
201 .join(", ");
202
203 Error::unsupported_query(format!(
204 "Unsupported query form in M2. Unsupported features: [{}]. \
205 M2 supports: SELECT with partition/primary key equality and optional LIMIT. \
206 Try narrowing your WHERE clause to use only equality (=) on partition/primary keys.",
207 feature_list
208 ))
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `cql` and returns the result, failing the test if the query is rejected.
    fn accept(cql: &str) -> SelectValidationResult {
        M2SelectValidator
            .validate_select(cql)
            .unwrap_or_else(|err| panic!("expected {:?} to be accepted, got: {}", cql, err))
    }

    /// Validates `cql` and returns the error text, failing the test if the query is accepted.
    fn reject(cql: &str) -> String {
        match M2SelectValidator.validate_select(cql) {
            Ok(result) => panic!("expected {:?} to be rejected, got: {:?}", cql, result),
            Err(err) => err.to_string(),
        }
    }

    /// Checks the flags of an accepted query and that no unsupported feature was recorded.
    fn assert_flags(
        result: &SelectValidationResult,
        partition_key: bool,
        clustering: bool,
        limit: bool,
    ) {
        assert_eq!(result.has_partition_key_filter, partition_key);
        assert_eq!(result.has_clustering_filters, clustering);
        assert_eq!(result.has_limit, limit);
        assert!(result.unsupported_features.is_empty());
    }

    #[test]
    fn test_simple_select_with_partition_key() {
        let result = accept("SELECT * FROM users WHERE user_id = 123");
        assert_flags(&result, true, false, false);
    }

    #[test]
    fn test_select_with_limit() {
        let result = accept("SELECT * FROM users WHERE user_id = 123 LIMIT 10");
        assert_flags(&result, true, false, true);
    }

    #[test]
    fn test_select_with_clustering_columns() {
        let result =
            accept("SELECT * FROM events WHERE user_id = 123 AND timestamp = '2024-01-01'");
        assert_flags(&result, true, true, false);
    }

    #[test]
    fn test_select_with_order_by() {
        let message = reject("SELECT * FROM users WHERE user_id = 123 ORDER BY name ASC");
        assert!(message.contains("ORDER BY"));
        assert!(message.contains("Unsupported query form in M2"));
    }

    #[test]
    fn test_select_with_allow_filtering() {
        let message =
            reject("SELECT * FROM users WHERE email = 'test@example.com' ALLOW FILTERING");
        assert!(message.contains("ALLOW FILTERING"));
    }

    #[test]
    fn test_select_with_count_aggregate() {
        let message = reject("SELECT COUNT(*) FROM users WHERE user_id = 123");
        assert!(message.contains("Aggregates"));
    }

    #[test]
    fn test_select_with_sum_aggregate() {
        let message = reject("SELECT SUM(amount) FROM transactions WHERE user_id = 123");
        assert!(message.contains("Aggregates"));
    }

    #[test]
    fn test_select_with_group_by() {
        let message = reject("SELECT user_id, COUNT(*) FROM users GROUP BY user_id");
        assert!(message.contains("GROUP BY"));
    }

    #[test]
    fn test_select_with_having() {
        let message =
            reject("SELECT user_id, COUNT(*) FROM users GROUP BY user_id HAVING COUNT(*) > 5");
        assert!(message.contains("HAVING"));
    }

    #[test]
    fn test_select_with_join() {
        let message = reject("SELECT u.* FROM users u JOIN orders o ON u.user_id = o.user_id");
        assert!(message.contains("JOIN"));
    }

    #[test]
    fn test_select_with_greater_than() {
        let message = reject("SELECT * FROM users WHERE age > 18");
        assert!(message.contains("Range queries"));
    }

    #[test]
    fn test_select_with_less_than_or_equal() {
        let message = reject("SELECT * FROM users WHERE age <= 65");
        assert!(message.contains("Range queries"));
    }

    #[test]
    fn test_select_with_not_equal() {
        let message = reject("SELECT * FROM users WHERE status != 'deleted'");
        assert!(message.contains("Range queries"));
    }

    #[test]
    fn test_select_with_not_equal_alternative() {
        let message = reject("SELECT * FROM users WHERE status <> 'deleted'");
        assert!(message.contains("Range queries"));
    }

    #[test]
    fn test_select_with_multiple_unsupported_features() {
        let message = reject(
            "SELECT COUNT(*) FROM users WHERE age > 18 GROUP BY country ORDER BY COUNT(*) DESC",
        );
        for needle in ["ORDER BY", "Aggregates", "GROUP BY", "Range queries"] {
            assert!(message.contains(needle), "missing {:?} in {:?}", needle, message);
        }
    }

    #[test]
    fn test_case_insensitive_detection() {
        for cql in [
            "select * from users where user_id = 123 order by name",
            "SeLeCt * FrOm users WhErE user_id = 123 OrDeR bY name",
        ] {
            assert!(
                M2SelectValidator.validate_select(cql).is_err(),
                "{:?} should be rejected",
                cql
            );
        }
    }

    #[test]
    fn test_unsupported_feature_display() {
        let expected = [
            (UnsupportedFeature::OrderBy, "ORDER BY"),
            (UnsupportedFeature::AllowFiltering, "ALLOW FILTERING"),
            (
                UnsupportedFeature::Aggregates,
                "Aggregates (COUNT, SUM, AVG, MIN, MAX)",
            ),
            (UnsupportedFeature::GroupBy, "GROUP BY"),
            (UnsupportedFeature::Having, "HAVING"),
            (UnsupportedFeature::Joins, "JOINs"),
            (
                UnsupportedFeature::RangeQueries,
                "Range queries (>, <, >=, <=, !=, <>)",
            ),
        ];
        for (feature, label) in expected {
            assert_eq!(feature.to_string(), label);
        }
    }

    #[test]
    fn test_validation_result_equality() {
        let build = || SelectValidationResult {
            has_partition_key_filter: true,
            has_clustering_filters: false,
            has_limit: true,
            unsupported_features: vec![],
        };
        assert_eq!(build(), build());
    }

    #[test]
    fn test_all_aggregate_functions() {
        for name in ["COUNT", "SUM", "AVG", "MIN", "MAX"] {
            let cql = format!("SELECT {}(*) FROM users WHERE user_id = 123", name);
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect {} aggregate",
                name
            );
        }
    }

    #[test]
    fn test_all_join_types() {
        let join_types = [
            "JOIN",
            "INNER JOIN",
            "LEFT JOIN",
            "RIGHT JOIN",
            "FULL JOIN",
            "CROSS JOIN",
        ];
        for join_type in join_types {
            let cql = format!(
                "SELECT * FROM users {} orders ON users.id = orders.user_id",
                join_type
            );
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect {} join",
                join_type
            );
        }
    }

    #[test]
    fn test_all_range_operators() {
        for operator in [">", "<", ">=", "<=", "!=", "<>"] {
            let cql = format!("SELECT * FROM users WHERE age {} 18", operator);
            assert!(
                M2SelectValidator.validate_select(&cql).is_err(),
                "Should detect range operator: {}",
                operator
            );
        }
    }

    #[test]
    fn test_select_without_where() {
        let result = accept("SELECT * FROM users");
        assert_flags(&result, false, false, false);
    }

    #[test]
    fn test_complex_valid_query() {
        let cql = "SELECT user_id, name, email FROM users \
                   WHERE user_id = 123 AND status = 'active' LIMIT 100";
        assert_flags(&accept(cql), true, true, true);
    }
}