1use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{
8 Token, TokenSpan, TokenStream, check_redundant_parens, tokenize,
9};
10use crate::expressions::{PathElement, TrackedExpressionAttributes};
11use crate::types::AttributeValue;
12
/// A parsed `KeyConditionExpression` whose `#name` placeholders are already resolved but whose
/// `:value` placeholders are still unresolved references (see `resolve_values`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyCondition {
    /// Partition key attribute name (after `#name` substitution).
    pub pk_name: String,
    /// Value placeholder compared against the partition key, e.g. `":pk"`.
    pub pk_value_ref: String,
    /// Optional sort-key condition; `None` when the expression only constrains the partition key.
    pub sk_condition: Option<SortKeyCondition>,
}

/// A sort-key condition referencing values by placeholder.
///
/// The first field of every variant is the sort key attribute name; the remaining fields are
/// `:value` placeholder names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortKeyCondition {
    /// `sk = :v`
    Eq(String, String),
    /// `sk < :v`
    Lt(String, String),
    /// `sk <= :v`
    Le(String, String),
    /// `sk > :v`
    Gt(String, String),
    /// `sk >= :v`
    Ge(String, String),
    /// `sk BETWEEN :lo AND :hi` (lower bound placeholder first).
    Between(String, String, String),
    /// `begins_with(sk, :prefix)`
    BeginsWith(String, String),
}
35
36pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
41 let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
42 check_redundant_parens(&tokens).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
45 let tokens = strip_outer_parens(tokens);
46 let mut stream = TokenStream::new(tokens);
47
48 let cond1 = parse_single_condition(&mut stream, tracker)?;
49
50 let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
51 stream.next();
52 let cond2 = parse_single_condition(&mut stream, tracker)?;
53 match (cond1, cond2) {
54 (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
55 (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
56 _ => {
57 return Err(
58 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
59 );
60 }
61 }
62 } else {
63 match cond1 {
64 ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
65 _ => {
66 return Err(
67 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
68 );
69 }
70 }
71 };
72
73 if !stream.at_end() {
74 return Err(format!(
75 "Unexpected token in KeyConditionExpression: {}",
76 stream.peek().unwrap()
77 ));
78 }
79
80 let (pk_name, pk_value_ref) = pk_cond;
81 let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
82
83 Ok(KeyCondition {
84 pk_name,
85 pk_value_ref,
86 sk_condition,
87 })
88}
89
90pub fn resolve_values(
92 condition: &KeyCondition,
93 tracker: &TrackedExpressionAttributes,
94) -> Result<ResolvedKeyCondition, String> {
95 let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
96
97 let sk = if let Some(ref sk_cond) = condition.sk_condition {
98 Some(resolve_sk_condition(sk_cond, tracker)?)
99 } else {
100 None
101 };
102
103 Ok(ResolvedKeyCondition {
104 pk_name: condition.pk_name.clone(),
105 pk_value: pk_val,
106 sk_condition: sk,
107 })
108}
109
110#[derive(Debug)]
112pub struct ResolvedKeyCondition {
113 pub pk_name: String,
114 pub pk_value: AttributeValue,
115 pub sk_condition: Option<ResolvedSortKeyCondition>,
116}
117
118#[derive(Debug)]
119pub enum ResolvedSortKeyCondition {
120 Eq(String, AttributeValue),
121 Lt(String, AttributeValue),
122 Le(String, AttributeValue),
123 Gt(String, AttributeValue),
124 Ge(String, AttributeValue),
125 Between(String, AttributeValue, AttributeValue),
126 BeginsWith(String, AttributeValue),
127}
128
129impl ResolvedSortKeyCondition {
130 pub fn sk_name(&self) -> &str {
131 match self {
132 Self::Eq(n, _)
133 | Self::Lt(n, _)
134 | Self::Le(n, _)
135 | Self::Gt(n, _)
136 | Self::Ge(n, _)
137 | Self::Between(n, _, _)
138 | Self::BeginsWith(n, _) => n,
139 }
140 }
141
142 pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
146 match self {
147 Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
148 Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
149 Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
150 Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
151 Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
152 Self::Between(_, lo, hi) => vec![
153 (">=".into(), val_to_key_string(lo)),
154 ("<=".into(), val_to_key_string(hi)),
155 ],
156 Self::BeginsWith(_, prefix) => {
157 let prefix_str = val_to_key_string(prefix);
158 let escaped = prefix_str
160 .replace('\\', "\\\\")
161 .replace('%', "\\%")
162 .replace('_', "\\_");
163 vec![("LIKE".into(), format!("{escaped}%"))]
164 }
165 }
166 }
167}
168
169fn val_to_key_string(val: &AttributeValue) -> String {
170 val.to_key_string().unwrap_or_default()
171}
172
173fn resolve_sk_condition(
174 cond: &SortKeyCondition,
175 tracker: &TrackedExpressionAttributes,
176) -> Result<ResolvedSortKeyCondition, String> {
177 match cond {
178 SortKeyCondition::Eq(sk, vr) => {
179 let v = tracker.resolve_value(vr)?.clone();
180 Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
181 }
182 SortKeyCondition::Lt(sk, vr) => {
183 let v = tracker.resolve_value(vr)?.clone();
184 Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
185 }
186 SortKeyCondition::Le(sk, vr) => {
187 let v = tracker.resolve_value(vr)?.clone();
188 Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
189 }
190 SortKeyCondition::Gt(sk, vr) => {
191 let v = tracker.resolve_value(vr)?.clone();
192 Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
193 }
194 SortKeyCondition::Ge(sk, vr) => {
195 let v = tracker.resolve_value(vr)?.clone();
196 Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
197 }
198 SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
199 let lo = tracker.resolve_value(lo_ref)?.clone();
200 let hi = tracker.resolve_value(hi_ref)?.clone();
201 if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
203 return Err(format!(
204 "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
205 for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
206 upper bound operand: AttributeValue: {{{}}}",
207 format_attr_value_short(&lo),
208 format_attr_value_short(&hi)
209 ));
210 }
211 if !between_order_valid(&lo, &hi) {
213 return Err(format!(
214 "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
215 to be greater than or equal to lower bound; lower bound operand: \
216 AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
217 format_attr_value_short(&lo),
218 format_attr_value_short(&hi)
219 ));
220 }
221 Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
222 }
223 SortKeyCondition::BeginsWith(sk, vr) => {
224 let v = tracker.resolve_value(vr)?.clone();
225 Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
226 }
227 }
228}
229
/// One comparison parsed from the expression, before it is known whether it constrains the
/// partition key or the sort key.
///
/// Field layout: attribute name first, then `:value` placeholder name(s).
#[derive(Debug)]
enum ParsedCond {
    /// `name = :v`
    Eq(String, String),
    /// `name < :v`
    Lt(String, String),
    /// `name <= :v`
    Le(String, String),
    /// `name > :v`
    Gt(String, String),
    /// `name >= :v`
    Ge(String, String),
    /// `name BETWEEN :lo AND :hi`
    Between(String, String, String),
    /// `begins_with(name, :prefix)`
    BeginsWith(String, String),
}
244
245impl ParsedCond {
246 fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
247 match self {
248 ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
249 ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
250 ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
251 ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
252 ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
253 ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
254 ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
255 }
256 }
257}
258
259fn strip_outer_parens(mut tokens: Vec<(Token, TokenSpan)>) -> Vec<(Token, TokenSpan)> {
264 loop {
265 if tokens.len() < 2 {
266 break;
267 }
268 if !matches!(tokens.first().map(|(t, _)| t), Some(Token::LParen)) {
269 break;
270 }
271 let mut depth = 0;
274 let mut close_pos = None;
275 for (i, (tok, _)) in tokens.iter().enumerate() {
276 match tok {
277 Token::LParen => depth += 1,
278 Token::RParen => {
279 depth -= 1;
280 if depth == 0 {
281 close_pos = Some(i);
282 break;
283 }
284 }
285 _ => {}
286 }
287 }
288 if close_pos == Some(tokens.len() - 1) {
289 tokens.remove(tokens.len() - 1);
291 tokens.remove(0);
292 } else {
293 break;
294 }
295 }
296 tokens
297}
298
299fn parse_single_condition(
300 stream: &mut TokenStream,
301 tracker: &TrackedExpressionAttributes,
302) -> Result<ParsedCond, String> {
303 let mut parens = 0;
305 while matches!(stream.peek(), Some(Token::LParen)) {
306 stream.next();
307 parens += 1;
308 }
309
310 if let Some(Token::Identifier(name)) = stream.peek() {
312 if name.to_lowercase() == "begins_with" {
313 stream.next();
314 stream.expect(&Token::LParen)?;
315 let path = parse_raw_path(stream)?;
316 let attr_name = resolve_path_to_name(&path, tracker)?;
317 stream.expect(&Token::Comma)?;
318 let val_ref = expect_value_ref(stream)?;
319 stream.expect(&Token::RParen)?;
320 consume_close_parens(stream, parens)?;
321 return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
322 }
323 }
324
325 let path = parse_raw_path(stream)?;
327 let attr_name = resolve_path_to_name(&path, tracker)?;
328
329 let result = match stream.next() {
330 Some(Token::Eq) => {
331 let val_ref = expect_value_ref(stream)?;
332 Ok(ParsedCond::Eq(attr_name, val_ref))
333 }
334 Some(Token::Lt) => {
335 let val_ref = expect_value_ref(stream)?;
336 Ok(ParsedCond::Lt(attr_name, val_ref))
337 }
338 Some(Token::Le) => {
339 let val_ref = expect_value_ref(stream)?;
340 Ok(ParsedCond::Le(attr_name, val_ref))
341 }
342 Some(Token::Gt) => {
343 let val_ref = expect_value_ref(stream)?;
344 Ok(ParsedCond::Gt(attr_name, val_ref))
345 }
346 Some(Token::Ge) => {
347 let val_ref = expect_value_ref(stream)?;
348 Ok(ParsedCond::Ge(attr_name, val_ref))
349 }
350 Some(Token::Between) => {
351 let lo_ref = expect_value_ref(stream)?;
352 stream.expect(&Token::And)?;
353 let hi_ref = expect_value_ref(stream)?;
354 Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
355 }
356 Some(t) => Err(format!(
357 "Unexpected operator in KeyConditionExpression: {t}"
358 )),
359 None => Err("Unexpected end of KeyConditionExpression".to_string()),
360 };
361
362 consume_close_parens(stream, parens)?;
363 result
364}
365
366fn consume_close_parens(stream: &mut TokenStream, count: usize) -> Result<(), String> {
368 for _ in 0..count {
369 match stream.next() {
370 Some(Token::RParen) => {}
371 Some(t) => {
372 return Err(format!(
373 "Expected closing parenthesis in KeyConditionExpression, got {t}"
374 ));
375 }
376 None => {
377 return Err(
378 "Unexpected end of KeyConditionExpression, expected closing parenthesis"
379 .to_string(),
380 );
381 }
382 }
383 }
384 Ok(())
385}
386
387fn resolve_path_to_name(
388 path: &[PathElement],
389 tracker: &TrackedExpressionAttributes,
390) -> Result<String, String> {
391 if path.len() != 1 {
392 return Err("KeyConditionExpression only supports top-level attributes".to_string());
393 }
394 match &path[0] {
395 PathElement::Attribute(name) => {
396 if name.starts_with('#') {
397 tracker.resolve_name(name)
398 } else {
399 Ok(name.clone())
400 }
401 }
402 PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
403 }
404}
405
406fn format_attr_value_short(val: &AttributeValue) -> String {
408 match val {
409 AttributeValue::S(s) => format!("S:{s}"),
410 AttributeValue::N(n) => format!("N:{n}"),
411 AttributeValue::B(b) => {
412 use base64::Engine;
413 let encoded = base64::engine::general_purpose::STANDARD.encode(b);
414 format!("B:{encoded}")
415 }
416 AttributeValue::BOOL(b) => format!("BOOL:{b}"),
417 AttributeValue::NULL(_) => "NULL:true".to_string(),
418 AttributeValue::SS(set) => format!("SS:{:?}", set),
419 AttributeValue::NS(set) => format!("NS:{:?}", set),
420 AttributeValue::BS(_) => "BS:[...]".to_string(),
421 AttributeValue::L(_) => "L:[...]".to_string(),
422 AttributeValue::M(_) => "M:{...}".to_string(),
423 }
424}
425
426fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
428 match (lo, hi) {
429 (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
430 (AttributeValue::N(a), AttributeValue::N(b)) => {
431 let a_f = a.parse::<f64>().unwrap_or(0.0);
432 let b_f = b.parse::<f64>().unwrap_or(0.0);
433 a_f <= b_f
434 }
435 (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
436 _ => true,
437 }
438}
439
440fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
441 match stream.next() {
442 Some(Token::ValueRef(name)) => Ok(name.clone()),
443 Some(t) => Err(format!("Expected value reference (:name), got {t}")),
444 None => Err("Expected value reference, got end of expression".to_string()),
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// Parses `pk = :pk AND sk BETWEEN :lo AND :hi` and resolves it with the given bounds.
    fn resolve_between_bounds(
        lo: AttributeValue,
        hi: AttributeValue,
    ) -> Result<ResolvedKeyCondition, String> {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("p".into())),
            (":lo".to_string(), lo),
            (":hi".to_string(), hi),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        resolve_values(&kc, &resolve_tracker)
    }

    #[test]
    fn test_pk_only() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(kc.sk_condition.is_none());
    }

    #[test]
    fn test_pk_and_sk_eq() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk = :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));
    }

    #[test]
    fn test_pk_and_sk_between() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    #[test]
    fn test_pk_and_begins_with() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND begins_with(sk, :prefix)", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    #[test]
    fn test_with_attribute_names() {
        let an = Some(HashMap::from([
            ("#pk".to_string(), "partitionKey".to_string()),
            ("#sk".to_string(), "sortKey".to_string()),
        ]));
        let no_values = None;
        let tracker = make_tracker(&an, &no_values);
        let kc = parse("#pk = :pk AND #sk > :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "partitionKey");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey"));
    }

    #[test]
    fn test_resolve_values() {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk >= :sk", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        let resolved = resolve_values(&kc, &resolve_tracker).unwrap();
        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    #[test]
    fn test_parenthesized_conditions() {
        let no_names = None;
        let no_values = None;

        // Each condition wrapped individually.
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk) AND (sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Whole expression wrapped.
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk AND sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Whole expression wrapped plus an inner wrapped condition.
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk AND (sk > :sk))", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(_, _))));

        // Doubled parentheses are rejected as redundant.
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("((pk = :pk)) AND ((sk > :sk))", &tracker).unwrap_err();
        assert!(
            err.contains("redundant parentheses"),
            "expected redundant-parentheses rejection, got: {err}"
        );

        // Parenthesized begins_with.
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk) AND (begins_with(sk, :prefix))", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Parenthesized conditions with attribute-name placeholders.
        let an = Some(HashMap::from([
            ("#pk".to_string(), "PK".to_string()),
            ("#sk".to_string(), "SK".to_string()),
        ]));
        let tracker = make_tracker(&an, &no_values);
        let kc = parse("(#pk = :pk) AND (#sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "PK");
    }

    #[test]
    fn test_sk_comparisons() {
        let no_names = None;
        let no_values = None;
        for (op, variant) in [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")] {
            let tracker = make_tracker(&no_names, &no_values);
            let kc = parse(&format!("pk = :pk AND sk {op} :sk"), &tracker).unwrap();
            let sk = kc.sk_condition.unwrap();
            let name = match &sk {
                SortKeyCondition::Lt(n, _) => format!("Lt:{n}"),
                SortKeyCondition::Le(n, _) => format!("Le:{n}"),
                SortKeyCondition::Gt(n, _) => format!("Gt:{n}"),
                SortKeyCondition::Ge(n, _) => format!("Ge:{n}"),
                _ => "other".to_string(),
            };
            assert!(name.starts_with(variant), "Expected {variant}, got {name}");
        }
    }

    #[test]
    fn test_partition_key_requires_equality() {
        let no_names = None;
        let no_values = None;

        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk < :pk", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "got: {err}");

        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk > :pk AND sk < :sk", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "got: {err}");
    }

    #[test]
    fn test_sort_key_condition_may_come_first() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("sk > :sk AND pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sk"));
    }

    #[test]
    fn test_trailing_tokens_rejected() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk = :pk AND sk = :sk AND sk = :extra", &tracker).unwrap_err();
        assert!(
            err.starts_with("Unexpected token in KeyConditionExpression"),
            "got: {err}"
        );
    }

    #[test]
    fn test_between_bounds_validation() {
        let err = resolve_between_bounds(
            AttributeValue::S("a".into()),
            AttributeValue::N("1".into()),
        )
        .unwrap_err();
        assert!(err.contains("requires same data type"), "got: {err}");

        let err = resolve_between_bounds(
            AttributeValue::S("b".into()),
            AttributeValue::S("a".into()),
        )
        .unwrap_err();
        assert!(
            err.contains("requires upper bound to be greater than or equal to lower bound"),
            "got: {err}"
        );

        // Numbers are ordered numerically, not lexicographically ("9" < "10").
        let resolved = resolve_between_bounds(
            AttributeValue::N("9".into()),
            AttributeValue::N("10".into()),
        )
        .unwrap();
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Between(_, _, _))
        ));

        let err = resolve_between_bounds(
            AttributeValue::N("10".into()),
            AttributeValue::N("9".into()),
        )
        .unwrap_err();
        assert!(err.contains("requires upper bound"), "got: {err}");
    }
}