1use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{Token, TokenStream, tokenize};
8use crate::expressions::{PathElement, TrackedExpressionAttributes};
9use crate::types::AttributeValue;
10
/// Parsed form of a `KeyConditionExpression`, with value placeholders not yet resolved.
///
/// Produced by [`parse`]. The `:value` references are resolved into concrete values
/// later by [`resolve_values`].
#[derive(Debug)]
pub struct KeyCondition {
    /// Attribute name used in the equality condition. A `#name` placeholder has already
    /// been substituted.
    pub pk_name: String,
    /// Value placeholder (e.g. `:pk`) that `pk_name` must equal.
    pub pk_value_ref: String,
    /// The second condition of the expression, or `None` when only one was given.
    pub sk_condition: Option<SortKeyCondition>,
}
21
/// Sort-key condition as parsed, with its operand(s) still unresolved value placeholders.
///
/// Every variant stores the attribute name first, then the placeholder(s).
#[derive(Debug)]
pub enum SortKeyCondition {
    /// `name = :v`
    Eq(String, String),
    /// `name < :v`
    Lt(String, String),
    /// `name <= :v`
    Le(String, String),
    /// `name > :v`
    Gt(String, String),
    /// `name >= :v`
    Ge(String, String),
    /// `name BETWEEN :lo AND :hi` — fields are (name, lower placeholder, upper placeholder).
    Between(String, String, String),
    /// `begins_with(name, :prefix)`
    BeginsWith(String, String),
}
33
34pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
36 let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
37 let mut stream = TokenStream::new(tokens);
38
39 let cond1 = parse_single_condition(&mut stream, tracker)?;
40
41 let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
42 stream.next();
43 let cond2 = parse_single_condition(&mut stream, tracker)?;
44 match (cond1, cond2) {
45 (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
46 (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
47 _ => {
48 return Err(
49 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
50 );
51 }
52 }
53 } else {
54 match cond1 {
55 ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
56 _ => {
57 return Err(
58 "Invalid KeyConditionExpression: partition key must use equality".to_string(),
59 );
60 }
61 }
62 };
63
64 if !stream.at_end() {
65 return Err(format!(
66 "Unexpected token in KeyConditionExpression: {}",
67 stream.peek().unwrap()
68 ));
69 }
70
71 let (pk_name, pk_value_ref) = pk_cond;
72 let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
73
74 Ok(KeyCondition {
75 pk_name,
76 pk_value_ref,
77 sk_condition,
78 })
79}
80
81pub fn resolve_values(
83 condition: &KeyCondition,
84 tracker: &TrackedExpressionAttributes,
85) -> Result<ResolvedKeyCondition, String> {
86 let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
87
88 let sk = if let Some(ref sk_cond) = condition.sk_condition {
89 Some(resolve_sk_condition(sk_cond, tracker)?)
90 } else {
91 None
92 };
93
94 Ok(ResolvedKeyCondition {
95 pk_name: condition.pk_name.clone(),
96 pk_value: pk_val,
97 sk_condition: sk,
98 })
99}
100
/// A key condition whose value placeholders have been replaced by concrete values.
///
/// Built by [`resolve_values`] from a [`KeyCondition`].
#[derive(Debug)]
pub struct ResolvedKeyCondition {
    /// Partition-key attribute name (name placeholders already substituted).
    pub pk_name: String,
    /// Value the partition key must equal.
    pub pk_value: AttributeValue,
    /// Resolved sort-key condition, if the expression had a second condition.
    pub sk_condition: Option<ResolvedSortKeyCondition>,
}
108
/// Sort-key condition with concrete operand values.
///
/// Every variant stores the sort-key attribute name first, then the value(s).
#[derive(Debug)]
pub enum ResolvedSortKeyCondition {
    /// `name = value`
    Eq(String, AttributeValue),
    /// `name < value`
    Lt(String, AttributeValue),
    /// `name <= value`
    Le(String, AttributeValue),
    /// `name > value`
    Gt(String, AttributeValue),
    /// `name >= value`
    Ge(String, AttributeValue),
    /// `name BETWEEN lower AND upper` (inclusive on both ends).
    Between(String, AttributeValue, AttributeValue),
    /// `begins_with(name, prefix)`
    BeginsWith(String, AttributeValue),
}
119
120impl ResolvedSortKeyCondition {
121 pub fn sk_name(&self) -> &str {
122 match self {
123 Self::Eq(n, _)
124 | Self::Lt(n, _)
125 | Self::Le(n, _)
126 | Self::Gt(n, _)
127 | Self::Ge(n, _)
128 | Self::Between(n, _, _)
129 | Self::BeginsWith(n, _) => n,
130 }
131 }
132
133 pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
137 match self {
138 Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
139 Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
140 Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
141 Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
142 Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
143 Self::Between(_, lo, hi) => vec![
144 (">=".into(), val_to_key_string(lo)),
145 ("<=".into(), val_to_key_string(hi)),
146 ],
147 Self::BeginsWith(_, prefix) => {
148 let prefix_str = val_to_key_string(prefix);
149 let escaped = prefix_str
151 .replace('\\', "\\\\")
152 .replace('%', "\\%")
153 .replace('_', "\\_");
154 vec![("LIKE".into(), format!("{escaped}%"))]
155 }
156 }
157 }
158}
159
160fn val_to_key_string(val: &AttributeValue) -> String {
161 val.to_key_string().unwrap_or_default()
162}
163
164fn resolve_sk_condition(
165 cond: &SortKeyCondition,
166 tracker: &TrackedExpressionAttributes,
167) -> Result<ResolvedSortKeyCondition, String> {
168 match cond {
169 SortKeyCondition::Eq(sk, vr) => {
170 let v = tracker.resolve_value(vr)?.clone();
171 Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
172 }
173 SortKeyCondition::Lt(sk, vr) => {
174 let v = tracker.resolve_value(vr)?.clone();
175 Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
176 }
177 SortKeyCondition::Le(sk, vr) => {
178 let v = tracker.resolve_value(vr)?.clone();
179 Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
180 }
181 SortKeyCondition::Gt(sk, vr) => {
182 let v = tracker.resolve_value(vr)?.clone();
183 Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
184 }
185 SortKeyCondition::Ge(sk, vr) => {
186 let v = tracker.resolve_value(vr)?.clone();
187 Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
188 }
189 SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
190 let lo = tracker.resolve_value(lo_ref)?.clone();
191 let hi = tracker.resolve_value(hi_ref)?.clone();
192 if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
194 return Err(format!(
195 "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
196 for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
197 upper bound operand: AttributeValue: {{{}}}",
198 format_attr_value_short(&lo),
199 format_attr_value_short(&hi)
200 ));
201 }
202 if !between_order_valid(&lo, &hi) {
204 return Err(format!(
205 "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
206 to be greater than or equal to lower bound; lower bound operand: \
207 AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
208 format_attr_value_short(&lo),
209 format_attr_value_short(&hi)
210 ));
211 }
212 Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
213 }
214 SortKeyCondition::BeginsWith(sk, vr) => {
215 let v = tracker.resolve_value(vr)?.clone();
216 Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
217 }
218 }
219}
220
/// A single condition as parsed from the token stream.
///
/// At this stage it is not yet decided whether the condition applies to the partition key
/// or the sort key. Fields are the attribute name followed by the value placeholder(s).
#[derive(Debug)]
enum ParsedCond {
    /// `name = :v`
    Eq(String, String),
    /// `name < :v`
    Lt(String, String),
    /// `name <= :v`
    Le(String, String),
    /// `name > :v`
    Gt(String, String),
    /// `name >= :v`
    Ge(String, String),
    /// `name BETWEEN :lo AND :hi`
    Between(String, String, String),
    /// `begins_with(name, :prefix)`
    BeginsWith(String, String),
}
235
236impl ParsedCond {
237 fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
238 match self {
239 ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
240 ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
241 ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
242 ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
243 ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
244 ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
245 ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
246 }
247 }
248}
249
250fn parse_single_condition(
251 stream: &mut TokenStream,
252 tracker: &TrackedExpressionAttributes,
253) -> Result<ParsedCond, String> {
254 if let Some(Token::Identifier(name)) = stream.peek() {
256 if name.to_lowercase() == "begins_with" {
257 stream.next();
258 stream.expect(&Token::LParen)?;
259 let path = parse_raw_path(stream)?;
260 let attr_name = resolve_path_to_name(&path, tracker)?;
261 stream.expect(&Token::Comma)?;
262 let val_ref = expect_value_ref(stream)?;
263 stream.expect(&Token::RParen)?;
264 return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
265 }
266 }
267
268 let path = parse_raw_path(stream)?;
270 let attr_name = resolve_path_to_name(&path, tracker)?;
271
272 match stream.next() {
273 Some(Token::Eq) => {
274 let val_ref = expect_value_ref(stream)?;
275 Ok(ParsedCond::Eq(attr_name, val_ref))
276 }
277 Some(Token::Lt) => {
278 let val_ref = expect_value_ref(stream)?;
279 Ok(ParsedCond::Lt(attr_name, val_ref))
280 }
281 Some(Token::Le) => {
282 let val_ref = expect_value_ref(stream)?;
283 Ok(ParsedCond::Le(attr_name, val_ref))
284 }
285 Some(Token::Gt) => {
286 let val_ref = expect_value_ref(stream)?;
287 Ok(ParsedCond::Gt(attr_name, val_ref))
288 }
289 Some(Token::Ge) => {
290 let val_ref = expect_value_ref(stream)?;
291 Ok(ParsedCond::Ge(attr_name, val_ref))
292 }
293 Some(Token::Between) => {
294 let lo_ref = expect_value_ref(stream)?;
295 stream.expect(&Token::And)?;
296 let hi_ref = expect_value_ref(stream)?;
297 Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
298 }
299 Some(t) => Err(format!(
300 "Unexpected operator in KeyConditionExpression: {t}"
301 )),
302 None => Err("Unexpected end of KeyConditionExpression".to_string()),
303 }
304}
305
306fn resolve_path_to_name(
307 path: &[PathElement],
308 tracker: &TrackedExpressionAttributes,
309) -> Result<String, String> {
310 if path.len() != 1 {
311 return Err("KeyConditionExpression only supports top-level attributes".to_string());
312 }
313 match &path[0] {
314 PathElement::Attribute(name) => {
315 if name.starts_with('#') {
316 tracker.resolve_name(name)
317 } else {
318 Ok(name.clone())
319 }
320 }
321 PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
322 }
323}
324
325fn format_attr_value_short(val: &AttributeValue) -> String {
327 match val {
328 AttributeValue::S(s) => format!("S:{s}"),
329 AttributeValue::N(n) => format!("N:{n}"),
330 AttributeValue::B(b) => {
331 use base64::Engine;
332 let encoded = base64::engine::general_purpose::STANDARD.encode(b);
333 format!("B:{encoded}")
334 }
335 AttributeValue::BOOL(b) => format!("BOOL:{b}"),
336 AttributeValue::NULL(_) => "NULL:true".to_string(),
337 AttributeValue::SS(set) => format!("SS:{:?}", set),
338 AttributeValue::NS(set) => format!("NS:{:?}", set),
339 AttributeValue::BS(_) => "BS:[...]".to_string(),
340 AttributeValue::L(_) => "L:[...]".to_string(),
341 AttributeValue::M(_) => "M:{...}".to_string(),
342 }
343}
344
345fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
347 match (lo, hi) {
348 (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
349 (AttributeValue::N(a), AttributeValue::N(b)) => {
350 let a_f = a.parse::<f64>().unwrap_or(0.0);
351 let b_f = b.parse::<f64>().unwrap_or(0.0);
352 a_f <= b_f
353 }
354 (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
355 _ => true,
356 }
357}
358
359fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
360 match stream.next() {
361 Some(Token::ValueRef(name)) => Ok(name.clone()),
362 Some(t) => Err(format!("Expected value reference (:name), got {t}")),
363 None => Err("Expected value reference, got end of expression".to_string()),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a tracker over optional name/value maps, as the parser and resolver expect.
    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_pk_only() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(kc.sk_condition.is_none());
    }

    #[test]
    fn test_pk_and_sk_eq() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk = :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));
    }

    #[test]
    fn test_pk_and_sk_between() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    #[test]
    fn test_pk_and_begins_with() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND begins_with(sk, :prefix)", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    #[test]
    fn test_with_attribute_names() {
        let an = Some(HashMap::from([
            ("#pk".to_string(), "partitionKey".to_string()),
            ("#sk".to_string(), "sortKey".to_string()),
        ]));
        let no_values = None;
        let tracker = make_tracker(&an, &no_values);
        let kc = parse("#pk = :pk AND #sk > :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "partitionKey");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey"));
    }

    #[test]
    fn test_resolve_values() {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk >= :sk", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        let resolved = resolve_values(&kc, &resolve_tracker).unwrap();
        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    #[test]
    fn test_sk_comparisons() {
        let no_names = None;
        let no_values = None;
        for (op, variant) in [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")] {
            let tracker = make_tracker(&no_names, &no_values);
            let kc = parse(&format!("pk = :pk AND sk {op} :sk"), &tracker).unwrap();
            let sk = kc.sk_condition.unwrap();
            let name = match &sk {
                SortKeyCondition::Lt(n, _) => format!("Lt:{n}"),
                SortKeyCondition::Le(n, _) => format!("Le:{n}"),
                SortKeyCondition::Gt(n, _) => format!("Gt:{n}"),
                SortKeyCondition::Ge(n, _) => format!("Ge:{n}"),
                _ => "other".to_string(),
            };
            assert!(name.starts_with(variant), "Expected {variant}, got {name}");
        }
    }

    /// The equality condition is the partition key even when it is written second.
    #[test]
    fn test_sk_before_pk() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("sk > :sk AND pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sk"));
    }

    #[test]
    fn test_pk_must_use_equality() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk > :pk", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "{err}");
        let err = parse("pk < :a AND sk > :b", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "{err}");
    }

    #[test]
    fn test_rejects_more_than_two_conditions() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk = :pk AND sk = :sk AND c3 = :c", &tracker).unwrap_err();
        assert!(
            err.starts_with("Unexpected token in KeyConditionExpression"),
            "{err}"
        );
    }

    #[test]
    fn test_value_must_be_placeholder() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk = pk2", &tracker).unwrap_err();
        assert!(err.starts_with("Expected value reference"), "{err}");
    }

    #[test]
    fn test_between_requires_same_type() {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("p".into())),
            (":lo".to_string(), AttributeValue::S("a".into())),
            (":hi".to_string(), AttributeValue::BOOL(true)),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        let err = resolve_values(&kc, &resolve_tracker).unwrap_err();
        assert!(err.contains("requires same data type"), "{err}");
    }

    #[test]
    fn test_between_rejects_inverted_bounds() {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("p".into())),
            (":lo".to_string(), AttributeValue::S("b".into())),
            (":hi".to_string(), AttributeValue::S("a".into())),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        let err = resolve_values(&kc, &resolve_tracker).unwrap_err();
        assert!(
            err.contains("upper bound to be greater than or equal to lower bound"),
            "{err}"
        );
    }
}