1use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{Token, TokenSpan, TokenStream, tokenize};
8use crate::expressions::{PathElement, TrackedExpressionAttributes};
9use crate::types::AttributeValue;
10
/// A parsed `KeyConditionExpression` whose value placeholders are still unresolved.
///
/// Produced by [`parse`]; hand it to [`resolve_values`] to replace the
/// placeholders with concrete attribute values.
#[derive(Debug)]
pub struct KeyCondition {
    /// Attribute name taken from the equality condition chosen as the partition key.
    pub pk_name: String,
    /// Value placeholder (for example `:pk`) the partition key is compared against.
    pub pk_value_ref: String,
    /// The other condition of the expression, if there was one; treated as the
    /// sort-key condition.
    pub sk_condition: Option<SortKeyCondition>,
}
21
/// Sort-key condition as written in the expression, before value resolution.
///
/// Every variant stores the sort-key attribute name first, followed by the
/// value placeholder(s) (for example `:sk`) it is compared against.
#[derive(Debug)]
pub enum SortKeyCondition {
    /// `name = :value`
    Eq(String, String),
    /// `name < :value`
    Lt(String, String),
    /// `name <= :value`
    Le(String, String),
    /// `name > :value`
    Gt(String, String),
    /// `name >= :value`
    Ge(String, String),
    /// `name BETWEEN :low AND :high`
    Between(String, String, String),
    /// `begins_with(name, :prefix)`
    BeginsWith(String, String),
}
33
34pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
39 let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
40 let tokens = strip_outer_parens(tokens);
41 let mut stream = TokenStream::new(tokens);
42
43 let cond1 = parse_single_condition(&mut stream, tracker)?;
44
45 let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
46 stream.next();
47 let cond2 = parse_single_condition(&mut stream, tracker)?;
48 match (cond1, cond2) {
49 (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
50 (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
51 _ => {
52 return Err(
53 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
54 );
55 }
56 }
57 } else {
58 match cond1 {
59 ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
60 _ => {
61 return Err(
62 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
63 );
64 }
65 }
66 };
67
68 if !stream.at_end() {
69 return Err(format!(
70 "Unexpected token in KeyConditionExpression: {}",
71 stream.peek().unwrap()
72 ));
73 }
74
75 let (pk_name, pk_value_ref) = pk_cond;
76 let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
77
78 Ok(KeyCondition {
79 pk_name,
80 pk_value_ref,
81 sk_condition,
82 })
83}
84
85pub fn resolve_values(
87 condition: &KeyCondition,
88 tracker: &TrackedExpressionAttributes,
89) -> Result<ResolvedKeyCondition, String> {
90 let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
91
92 let sk = if let Some(ref sk_cond) = condition.sk_condition {
93 Some(resolve_sk_condition(sk_cond, tracker)?)
94 } else {
95 None
96 };
97
98 Ok(ResolvedKeyCondition {
99 pk_name: condition.pk_name.clone(),
100 pk_value: pk_val,
101 sk_condition: sk,
102 })
103}
104
/// A key condition whose placeholders have been replaced by attribute values.
///
/// Built by [`resolve_values`] from a [`KeyCondition`].
#[derive(Debug)]
pub struct ResolvedKeyCondition {
    /// Attribute name of the partition key.
    pub pk_name: String,
    /// Value the partition key must equal.
    pub pk_value: AttributeValue,
    /// Resolved sort-key condition, when the expression contained one.
    pub sk_condition: Option<ResolvedSortKeyCondition>,
}
112
/// Sort-key condition with concrete operand values.
///
/// Mirrors [`SortKeyCondition`]: each variant stores the sort-key attribute
/// name first, followed by the resolved operand value(s).
#[derive(Debug)]
pub enum ResolvedSortKeyCondition {
    /// Sort key equals the value.
    Eq(String, AttributeValue),
    /// Sort key is less than the value.
    Lt(String, AttributeValue),
    /// Sort key is less than or equal to the value.
    Le(String, AttributeValue),
    /// Sort key is greater than the value.
    Gt(String, AttributeValue),
    /// Sort key is greater than or equal to the value.
    Ge(String, AttributeValue),
    /// Sort key lies between the lower and upper bound (both inclusive).
    Between(String, AttributeValue, AttributeValue),
    /// Sort key starts with the given prefix value.
    BeginsWith(String, AttributeValue),
}
123
124impl ResolvedSortKeyCondition {
125 pub fn sk_name(&self) -> &str {
126 match self {
127 Self::Eq(n, _)
128 | Self::Lt(n, _)
129 | Self::Le(n, _)
130 | Self::Gt(n, _)
131 | Self::Ge(n, _)
132 | Self::Between(n, _, _)
133 | Self::BeginsWith(n, _) => n,
134 }
135 }
136
137 pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
141 match self {
142 Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
143 Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
144 Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
145 Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
146 Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
147 Self::Between(_, lo, hi) => vec![
148 (">=".into(), val_to_key_string(lo)),
149 ("<=".into(), val_to_key_string(hi)),
150 ],
151 Self::BeginsWith(_, prefix) => {
152 let prefix_str = val_to_key_string(prefix);
153 let escaped = prefix_str
155 .replace('\\', "\\\\")
156 .replace('%', "\\%")
157 .replace('_', "\\_");
158 vec![("LIKE".into(), format!("{escaped}%"))]
159 }
160 }
161 }
162}
163
/// Converts an attribute value to its key-string form.
///
/// Falls back to an empty string (the `String` default) when
/// `to_key_string` does not produce a value.
fn val_to_key_string(val: &AttributeValue) -> String {
    val.to_key_string().unwrap_or_default()
}
167
168fn resolve_sk_condition(
169 cond: &SortKeyCondition,
170 tracker: &TrackedExpressionAttributes,
171) -> Result<ResolvedSortKeyCondition, String> {
172 match cond {
173 SortKeyCondition::Eq(sk, vr) => {
174 let v = tracker.resolve_value(vr)?.clone();
175 Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
176 }
177 SortKeyCondition::Lt(sk, vr) => {
178 let v = tracker.resolve_value(vr)?.clone();
179 Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
180 }
181 SortKeyCondition::Le(sk, vr) => {
182 let v = tracker.resolve_value(vr)?.clone();
183 Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
184 }
185 SortKeyCondition::Gt(sk, vr) => {
186 let v = tracker.resolve_value(vr)?.clone();
187 Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
188 }
189 SortKeyCondition::Ge(sk, vr) => {
190 let v = tracker.resolve_value(vr)?.clone();
191 Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
192 }
193 SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
194 let lo = tracker.resolve_value(lo_ref)?.clone();
195 let hi = tracker.resolve_value(hi_ref)?.clone();
196 if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
198 return Err(format!(
199 "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
200 for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
201 upper bound operand: AttributeValue: {{{}}}",
202 format_attr_value_short(&lo),
203 format_attr_value_short(&hi)
204 ));
205 }
206 if !between_order_valid(&lo, &hi) {
208 return Err(format!(
209 "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
210 to be greater than or equal to lower bound; lower bound operand: \
211 AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
212 format_attr_value_short(&lo),
213 format_attr_value_short(&hi)
214 ));
215 }
216 Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
217 }
218 SortKeyCondition::BeginsWith(sk, vr) => {
219 let v = tracker.resolve_value(vr)?.clone();
220 Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
221 }
222 }
223}
224
/// A single parsed condition, before it is known whether it constrains the
/// partition key or the sort key.
///
/// Each variant stores the attribute name first, followed by the value
/// placeholder(s) from the expression.
#[derive(Debug)]
enum ParsedCond {
    /// `name = :value`
    Eq(String, String),
    /// `name < :value`
    Lt(String, String),
    /// `name <= :value`
    Le(String, String),
    /// `name > :value`
    Gt(String, String),
    /// `name >= :value`
    Ge(String, String),
    /// `name BETWEEN :low AND :high`
    Between(String, String, String),
    /// `begins_with(name, :prefix)`
    BeginsWith(String, String),
}
239
240impl ParsedCond {
241 fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
242 match self {
243 ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
244 ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
245 ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
246 ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
247 ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
248 ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
249 ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
250 }
251 }
252}
253
254fn strip_outer_parens(mut tokens: Vec<(Token, TokenSpan)>) -> Vec<(Token, TokenSpan)> {
259 loop {
260 if tokens.len() < 2 {
261 break;
262 }
263 if !matches!(tokens.first().map(|(t, _)| t), Some(Token::LParen)) {
264 break;
265 }
266 let mut depth = 0;
269 let mut close_pos = None;
270 for (i, (tok, _)) in tokens.iter().enumerate() {
271 match tok {
272 Token::LParen => depth += 1,
273 Token::RParen => {
274 depth -= 1;
275 if depth == 0 {
276 close_pos = Some(i);
277 break;
278 }
279 }
280 _ => {}
281 }
282 }
283 if close_pos == Some(tokens.len() - 1) {
284 tokens.remove(tokens.len() - 1);
286 tokens.remove(0);
287 } else {
288 break;
289 }
290 }
291 tokens
292}
293
294fn parse_single_condition(
295 stream: &mut TokenStream,
296 tracker: &TrackedExpressionAttributes,
297) -> Result<ParsedCond, String> {
298 let mut parens = 0;
300 while matches!(stream.peek(), Some(Token::LParen)) {
301 stream.next();
302 parens += 1;
303 }
304
305 if let Some(Token::Identifier(name)) = stream.peek() {
307 if name.to_lowercase() == "begins_with" {
308 stream.next();
309 stream.expect(&Token::LParen)?;
310 let path = parse_raw_path(stream)?;
311 let attr_name = resolve_path_to_name(&path, tracker)?;
312 stream.expect(&Token::Comma)?;
313 let val_ref = expect_value_ref(stream)?;
314 stream.expect(&Token::RParen)?;
315 consume_close_parens(stream, parens)?;
316 return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
317 }
318 }
319
320 let path = parse_raw_path(stream)?;
322 let attr_name = resolve_path_to_name(&path, tracker)?;
323
324 let result = match stream.next() {
325 Some(Token::Eq) => {
326 let val_ref = expect_value_ref(stream)?;
327 Ok(ParsedCond::Eq(attr_name, val_ref))
328 }
329 Some(Token::Lt) => {
330 let val_ref = expect_value_ref(stream)?;
331 Ok(ParsedCond::Lt(attr_name, val_ref))
332 }
333 Some(Token::Le) => {
334 let val_ref = expect_value_ref(stream)?;
335 Ok(ParsedCond::Le(attr_name, val_ref))
336 }
337 Some(Token::Gt) => {
338 let val_ref = expect_value_ref(stream)?;
339 Ok(ParsedCond::Gt(attr_name, val_ref))
340 }
341 Some(Token::Ge) => {
342 let val_ref = expect_value_ref(stream)?;
343 Ok(ParsedCond::Ge(attr_name, val_ref))
344 }
345 Some(Token::Between) => {
346 let lo_ref = expect_value_ref(stream)?;
347 stream.expect(&Token::And)?;
348 let hi_ref = expect_value_ref(stream)?;
349 Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
350 }
351 Some(t) => Err(format!(
352 "Unexpected operator in KeyConditionExpression: {t}"
353 )),
354 None => Err("Unexpected end of KeyConditionExpression".to_string()),
355 };
356
357 consume_close_parens(stream, parens)?;
358 result
359}
360
361fn consume_close_parens(stream: &mut TokenStream, count: usize) -> Result<(), String> {
363 for _ in 0..count {
364 match stream.next() {
365 Some(Token::RParen) => {}
366 Some(t) => {
367 return Err(format!(
368 "Expected closing parenthesis in KeyConditionExpression, got {t}"
369 ));
370 }
371 None => {
372 return Err(
373 "Unexpected end of KeyConditionExpression, expected closing parenthesis"
374 .to_string(),
375 );
376 }
377 }
378 }
379 Ok(())
380}
381
382fn resolve_path_to_name(
383 path: &[PathElement],
384 tracker: &TrackedExpressionAttributes,
385) -> Result<String, String> {
386 if path.len() != 1 {
387 return Err("KeyConditionExpression only supports top-level attributes".to_string());
388 }
389 match &path[0] {
390 PathElement::Attribute(name) => {
391 if name.starts_with('#') {
392 tracker.resolve_name(name)
393 } else {
394 Ok(name.clone())
395 }
396 }
397 PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
398 }
399}
400
401fn format_attr_value_short(val: &AttributeValue) -> String {
403 match val {
404 AttributeValue::S(s) => format!("S:{s}"),
405 AttributeValue::N(n) => format!("N:{n}"),
406 AttributeValue::B(b) => {
407 use base64::Engine;
408 let encoded = base64::engine::general_purpose::STANDARD.encode(b);
409 format!("B:{encoded}")
410 }
411 AttributeValue::BOOL(b) => format!("BOOL:{b}"),
412 AttributeValue::NULL(_) => "NULL:true".to_string(),
413 AttributeValue::SS(set) => format!("SS:{:?}", set),
414 AttributeValue::NS(set) => format!("NS:{:?}", set),
415 AttributeValue::BS(_) => "BS:[...]".to_string(),
416 AttributeValue::L(_) => "L:[...]".to_string(),
417 AttributeValue::M(_) => "M:{...}".to_string(),
418 }
419}
420
421fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
423 match (lo, hi) {
424 (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
425 (AttributeValue::N(a), AttributeValue::N(b)) => {
426 let a_f = a.parse::<f64>().unwrap_or(0.0);
427 let b_f = b.parse::<f64>().unwrap_or(0.0);
428 a_f <= b_f
429 }
430 (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
431 _ => true,
432 }
433}
434
435fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
436 match stream.next() {
437 Some(Token::ValueRef(name)) => Ok(name.clone()),
438 Some(t) => Err(format!("Expected value reference (:name), got {t}")),
439 None => Err("Expected value reference, got end of expression".to_string()),
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a tracker over the given optional name and value maps.
    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// A lone equality yields a partition key and no sort-key condition.
    #[test]
    fn test_pk_only() {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);

        let parsed = parse("pk = :pk", &tracker).unwrap();

        assert_eq!(parsed.pk_name, "pk");
        assert_eq!(parsed.pk_value_ref, ":pk");
        assert!(parsed.sk_condition.is_none());
    }

    /// Two equalities: the first is the partition key, the second the sort key.
    #[test]
    fn test_pk_and_sk_eq() {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);

        let parsed = parse("pk = :pk AND sk = :sk", &tracker).unwrap();

        assert_eq!(parsed.pk_name, "pk");
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::Eq(_, _))
        ));
    }

    /// `BETWEEN` is accepted as the sort-key condition.
    #[test]
    fn test_pk_and_sk_between() {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);

        let parsed = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &tracker).unwrap();

        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    /// `begins_with(...)` is accepted as the sort-key condition.
    #[test]
    fn test_pk_and_begins_with() {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);

        let parsed = parse("pk = :pk AND begins_with(sk, :prefix)", &tracker).unwrap();

        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    /// `#name` placeholders are replaced by the mapped attribute names.
    #[test]
    fn test_with_attribute_names() {
        let names = Some(HashMap::from([
            ("#pk".to_string(), "partitionKey".to_string()),
            ("#sk".to_string(), "sortKey".to_string()),
        ]));
        let values = None;
        let tracker = make_tracker(&names, &values);

        let parsed = parse("#pk = :pk AND #sk > :sk", &tracker).unwrap();

        assert_eq!(parsed.pk_name, "partitionKey");
        assert!(
            matches!(parsed.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey")
        );
    }

    /// Placeholders are swapped for the concrete values held by the tracker.
    #[test]
    fn test_resolve_values() {
        let names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&names, &no_values);
        let parsed = parse("pk = :pk AND sk >= :sk", &parse_tracker).unwrap();

        let values = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let resolve_tracker = make_tracker(&names, &values);
        let resolved = resolve_values(&parsed, &resolve_tracker).unwrap();

        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    /// Parentheses around conditions, around the whole expression, and nested
    /// several levels deep are all tolerated.
    #[test]
    fn test_parenthesized_conditions() {
        let names = None;
        let values = None;

        // Each condition wrapped on its own.
        let tracker = make_tracker(&names, &values);
        let parsed = parse("(pk = :pk) AND (sk = :sk)", &tracker).unwrap();
        assert_eq!(parsed.pk_name, "pk");
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::Eq(_, _))
        ));

        // The whole expression wrapped once.
        let tracker = make_tracker(&names, &values);
        let parsed = parse("(pk = :pk AND sk = :sk)", &tracker).unwrap();
        assert_eq!(parsed.pk_name, "pk");
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::Eq(_, _))
        ));

        // Each condition wrapped twice.
        let tracker = make_tracker(&names, &values);
        let parsed = parse("((pk = :pk)) AND ((sk > :sk))", &tracker).unwrap();
        assert_eq!(parsed.pk_name, "pk");
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::Gt(_, _))
        ));

        // Outer wrapper combined with nested wrappers and a function call.
        let tracker = make_tracker(&names, &values);
        let parsed =
            parse("(((pk = :pk)) AND ((begins_with(sk, :prefix))))", &tracker).unwrap();
        assert_eq!(parsed.pk_name, "pk");
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Function call wrapped once.
        let tracker = make_tracker(&names, &values);
        let parsed = parse("(pk = :pk) AND (begins_with(sk, :prefix))", &tracker).unwrap();
        assert!(matches!(
            parsed.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Parentheses together with `#name` placeholders.
        let mapped_names = Some(HashMap::from([
            ("#pk".to_string(), "PK".to_string()),
            ("#sk".to_string(), "SK".to_string()),
        ]));
        let tracker = make_tracker(&mapped_names, &values);
        let parsed = parse("(#pk = :pk) AND (#sk = :sk)", &tracker).unwrap();
        assert_eq!(parsed.pk_name, "PK");
    }

    /// Each comparison operator maps to its own sort-key variant.
    #[test]
    fn test_sk_comparisons() {
        let names = None;
        let values = None;
        let cases = [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")];

        for (op, variant) in cases {
            let tracker = make_tracker(&names, &values);
            let expression = format!("pk = :pk AND sk {op} :sk");
            let parsed = parse(&expression, &tracker).unwrap();
            let sort_cond = parsed.sk_condition.unwrap();

            let name = match &sort_cond {
                SortKeyCondition::Lt(n, _) => format!("Lt:{n}"),
                SortKeyCondition::Le(n, _) => format!("Le:{n}"),
                SortKeyCondition::Gt(n, _) => format!("Gt:{n}"),
                SortKeyCondition::Ge(n, _) => format!("Ge:{n}"),
                _ => "other".to_string(),
            };
            assert!(name.starts_with(variant), "Expected {variant}, got {name}");
        }
    }
}