Skip to main content

dynoxide/expressions/
key_condition.rs

1//! KeyConditionExpression parsing.
2//!
3//! KeyConditionExpression supports: `pk = :val [AND sk_condition]`
4//! Sort key conditions: `=`, `<`, `<=`, `>`, `>=`, `BETWEEN ... AND ...`, `begins_with(sk, :prefix)`
5
6use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{
8    Token, TokenSpan, TokenStream, check_redundant_parens, tokenize,
9};
10use crate::expressions::{PathElement, TrackedExpressionAttributes};
11use crate::types::AttributeValue;
12
/// Parsed key condition, before any `:value` placeholders are resolved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyCondition {
    /// Partition key attribute name, with any `#name` placeholder already resolved.
    pub pk_name: String,
    /// Partition key value reference (e.g. `:pk`).
    pub pk_value_ref: String,
    /// Optional sort key condition.
    pub sk_condition: Option<SortKeyCondition>,
}

/// Sort key condition variants.
///
/// Every variant stores the (resolved) sort key attribute name first, followed
/// by the `:value` placeholder reference(s) it is compared against.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortKeyCondition {
    /// `sk = :v` — (sk_name, value_ref)
    Eq(String, String),
    /// `sk < :v` — (sk_name, value_ref)
    Lt(String, String),
    /// `sk <= :v` — (sk_name, value_ref)
    Le(String, String),
    /// `sk > :v` — (sk_name, value_ref)
    Gt(String, String),
    /// `sk >= :v` — (sk_name, value_ref)
    Ge(String, String),
    /// `sk BETWEEN :lo AND :hi` — (sk_name, lo_ref, hi_ref)
    Between(String, String, String),
    /// `begins_with(sk, :prefix)` — (sk_name, prefix_ref)
    BeginsWith(String, String),
}
35
36/// Parse a KeyConditionExpression string, tracking attribute name usage.
37///
38/// Supports optional parentheses around individual conditions and around
39/// the entire expression, matching DynamoDB behavior.
40pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
41    let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
42    // Reject redundant parens before stripping outer ones (strip_outer_parens
43    // would otherwise silently accept `((pk = :pk))`).
44    check_redundant_parens(&tokens).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
45    let tokens = strip_outer_parens(tokens);
46    let mut stream = TokenStream::new(tokens);
47
48    let cond1 = parse_single_condition(&mut stream, tracker)?;
49
50    let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
51        stream.next();
52        let cond2 = parse_single_condition(&mut stream, tracker)?;
53        match (cond1, cond2) {
54            (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
55            (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
56            _ => {
57                return Err(
58                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
59                );
60            }
61        }
62    } else {
63        match cond1 {
64            ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
65            _ => {
66                return Err(
67                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
68                );
69            }
70        }
71    };
72
73    if !stream.at_end() {
74        return Err(format!(
75            "Unexpected token in KeyConditionExpression: {}",
76            stream.peek().unwrap()
77        ));
78    }
79
80    let (pk_name, pk_value_ref) = pk_cond;
81    let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
82
83    Ok(KeyCondition {
84        pk_name,
85        pk_value_ref,
86        sk_condition,
87    })
88}
89
90/// Resolve the actual attribute values from the parsed key condition, tracking usage.
91pub fn resolve_values(
92    condition: &KeyCondition,
93    tracker: &TrackedExpressionAttributes,
94) -> Result<ResolvedKeyCondition, String> {
95    let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
96
97    let sk = if let Some(ref sk_cond) = condition.sk_condition {
98        Some(resolve_sk_condition(sk_cond, tracker)?)
99    } else {
100        None
101    };
102
103    Ok(ResolvedKeyCondition {
104        pk_name: condition.pk_name.clone(),
105        pk_value: pk_val,
106        sk_condition: sk,
107    })
108}
109
110/// Resolved key condition with actual values.
111#[derive(Debug)]
112pub struct ResolvedKeyCondition {
113    pub pk_name: String,
114    pub pk_value: AttributeValue,
115    pub sk_condition: Option<ResolvedSortKeyCondition>,
116}
117
118#[derive(Debug)]
119pub enum ResolvedSortKeyCondition {
120    Eq(String, AttributeValue),
121    Lt(String, AttributeValue),
122    Le(String, AttributeValue),
123    Gt(String, AttributeValue),
124    Ge(String, AttributeValue),
125    Between(String, AttributeValue, AttributeValue),
126    BeginsWith(String, AttributeValue),
127}
128
129impl ResolvedSortKeyCondition {
130    pub fn sk_name(&self) -> &str {
131        match self {
132            Self::Eq(n, _)
133            | Self::Lt(n, _)
134            | Self::Le(n, _)
135            | Self::Gt(n, _)
136            | Self::Ge(n, _)
137            | Self::Between(n, _, _)
138            | Self::BeginsWith(n, _) => n,
139        }
140    }
141
142    /// Convert to SQL WHERE clause components for sk column.
143    /// Returns (operator, value_string) pairs.
144    /// For BETWEEN, returns two conditions.
145    pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
146        match self {
147            Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
148            Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
149            Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
150            Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
151            Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
152            Self::Between(_, lo, hi) => vec![
153                (">=".into(), val_to_key_string(lo)),
154                ("<=".into(), val_to_key_string(hi)),
155            ],
156            Self::BeginsWith(_, prefix) => {
157                let prefix_str = val_to_key_string(prefix);
158                // Escape LIKE wildcards in the prefix value before appending %
159                let escaped = prefix_str
160                    .replace('\\', "\\\\")
161                    .replace('%', "\\%")
162                    .replace('_', "\\_");
163                vec![("LIKE".into(), format!("{escaped}%"))]
164            }
165        }
166    }
167}
168
169fn val_to_key_string(val: &AttributeValue) -> String {
170    val.to_key_string().unwrap_or_default()
171}
172
173fn resolve_sk_condition(
174    cond: &SortKeyCondition,
175    tracker: &TrackedExpressionAttributes,
176) -> Result<ResolvedSortKeyCondition, String> {
177    match cond {
178        SortKeyCondition::Eq(sk, vr) => {
179            let v = tracker.resolve_value(vr)?.clone();
180            Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
181        }
182        SortKeyCondition::Lt(sk, vr) => {
183            let v = tracker.resolve_value(vr)?.clone();
184            Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
185        }
186        SortKeyCondition::Le(sk, vr) => {
187            let v = tracker.resolve_value(vr)?.clone();
188            Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
189        }
190        SortKeyCondition::Gt(sk, vr) => {
191            let v = tracker.resolve_value(vr)?.clone();
192            Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
193        }
194        SortKeyCondition::Ge(sk, vr) => {
195            let v = tracker.resolve_value(vr)?.clone();
196            Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
197        }
198        SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
199            let lo = tracker.resolve_value(lo_ref)?.clone();
200            let hi = tracker.resolve_value(hi_ref)?.clone();
201            // Validate same type
202            if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
203                return Err(format!(
204                    "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
205                     for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
206                     upper bound operand: AttributeValue: {{{}}}",
207                    format_attr_value_short(&lo),
208                    format_attr_value_short(&hi)
209                ));
210            }
211            // Validate ordering (upper >= lower)
212            if !between_order_valid(&lo, &hi) {
213                return Err(format!(
214                    "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
215                     to be greater than or equal to lower bound; lower bound operand: \
216                     AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
217                    format_attr_value_short(&lo),
218                    format_attr_value_short(&hi)
219                ));
220            }
221            Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
222        }
223        SortKeyCondition::BeginsWith(sk, vr) => {
224            let v = tracker.resolve_value(vr)?.clone();
225            Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
226        }
227    }
228}
229
230// ---------------------------------------------------------------------------
231// Internal parsing helpers
232// ---------------------------------------------------------------------------
233
234#[derive(Debug)]
235enum ParsedCond {
236    Eq(String, String), // (attr_name, value_ref)
237    Lt(String, String),
238    Le(String, String),
239    Gt(String, String),
240    Ge(String, String),
241    Between(String, String, String), // (attr_name, lo_ref, hi_ref)
242    BeginsWith(String, String),      // (attr_name, prefix_ref)
243}
244
245impl ParsedCond {
246    fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
247        match self {
248            ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
249            ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
250            ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
251            ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
252            ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
253            ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
254            ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
255        }
256    }
257}
258
259/// Strip balanced outer parentheses from a token list.
260/// `(pk = :pk AND sk = :sk)` → `pk = :pk AND sk = :sk`
261/// `((pk = :pk))` → `pk = :pk` (applied repeatedly)
262/// `(pk = :pk) AND (sk = :sk)` → unchanged (closing paren is not at the end)
263fn strip_outer_parens(mut tokens: Vec<(Token, TokenSpan)>) -> Vec<(Token, TokenSpan)> {
264    loop {
265        if tokens.len() < 2 {
266            break;
267        }
268        if !matches!(tokens.first().map(|(t, _)| t), Some(Token::LParen)) {
269            break;
270        }
271        // Walk forward, tracking paren depth, to see if the opening paren's
272        // match is the very last token.
273        let mut depth = 0;
274        let mut close_pos = None;
275        for (i, (tok, _)) in tokens.iter().enumerate() {
276            match tok {
277                Token::LParen => depth += 1,
278                Token::RParen => {
279                    depth -= 1;
280                    if depth == 0 {
281                        close_pos = Some(i);
282                        break;
283                    }
284                }
285                _ => {}
286            }
287        }
288        if close_pos == Some(tokens.len() - 1) {
289            // The outermost parens wrap the entire expression — strip them.
290            tokens.remove(tokens.len() - 1);
291            tokens.remove(0);
292        } else {
293            break;
294        }
295    }
296    tokens
297}
298
299fn parse_single_condition(
300    stream: &mut TokenStream,
301    tracker: &TrackedExpressionAttributes,
302) -> Result<ParsedCond, String> {
303    // Count and skip optional wrapping parentheses: `(pk = :pk)`, `((pk = :pk))`
304    let mut parens = 0;
305    while matches!(stream.peek(), Some(Token::LParen)) {
306        stream.next();
307        parens += 1;
308    }
309
310    // Check for begins_with function
311    if let Some(Token::Identifier(name)) = stream.peek() {
312        if name.to_lowercase() == "begins_with" {
313            stream.next();
314            stream.expect(&Token::LParen)?;
315            let path = parse_raw_path(stream)?;
316            let attr_name = resolve_path_to_name(&path, tracker)?;
317            stream.expect(&Token::Comma)?;
318            let val_ref = expect_value_ref(stream)?;
319            stream.expect(&Token::RParen)?;
320            consume_close_parens(stream, parens)?;
321            return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
322        }
323    }
324
325    // attr op :val
326    let path = parse_raw_path(stream)?;
327    let attr_name = resolve_path_to_name(&path, tracker)?;
328
329    let result = match stream.next() {
330        Some(Token::Eq) => {
331            let val_ref = expect_value_ref(stream)?;
332            Ok(ParsedCond::Eq(attr_name, val_ref))
333        }
334        Some(Token::Lt) => {
335            let val_ref = expect_value_ref(stream)?;
336            Ok(ParsedCond::Lt(attr_name, val_ref))
337        }
338        Some(Token::Le) => {
339            let val_ref = expect_value_ref(stream)?;
340            Ok(ParsedCond::Le(attr_name, val_ref))
341        }
342        Some(Token::Gt) => {
343            let val_ref = expect_value_ref(stream)?;
344            Ok(ParsedCond::Gt(attr_name, val_ref))
345        }
346        Some(Token::Ge) => {
347            let val_ref = expect_value_ref(stream)?;
348            Ok(ParsedCond::Ge(attr_name, val_ref))
349        }
350        Some(Token::Between) => {
351            let lo_ref = expect_value_ref(stream)?;
352            stream.expect(&Token::And)?;
353            let hi_ref = expect_value_ref(stream)?;
354            Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
355        }
356        Some(t) => Err(format!(
357            "Unexpected operator in KeyConditionExpression: {t}"
358        )),
359        None => Err("Unexpected end of KeyConditionExpression".to_string()),
360    };
361
362    consume_close_parens(stream, parens)?;
363    result
364}
365
366/// Consume exactly `count` closing parentheses from the stream.
367fn consume_close_parens(stream: &mut TokenStream, count: usize) -> Result<(), String> {
368    for _ in 0..count {
369        match stream.next() {
370            Some(Token::RParen) => {}
371            Some(t) => {
372                return Err(format!(
373                    "Expected closing parenthesis in KeyConditionExpression, got {t}"
374                ));
375            }
376            None => {
377                return Err(
378                    "Unexpected end of KeyConditionExpression, expected closing parenthesis"
379                        .to_string(),
380                );
381            }
382        }
383    }
384    Ok(())
385}
386
387fn resolve_path_to_name(
388    path: &[PathElement],
389    tracker: &TrackedExpressionAttributes,
390) -> Result<String, String> {
391    if path.len() != 1 {
392        return Err("KeyConditionExpression only supports top-level attributes".to_string());
393    }
394    match &path[0] {
395        PathElement::Attribute(name) => {
396            if name.starts_with('#') {
397                tracker.resolve_name(name)
398            } else {
399                Ok(name.clone())
400            }
401        }
402        PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
403    }
404}
405
406/// Format an attribute value for error messages (DynamoDB short format).
407fn format_attr_value_short(val: &AttributeValue) -> String {
408    match val {
409        AttributeValue::S(s) => format!("S:{s}"),
410        AttributeValue::N(n) => format!("N:{n}"),
411        AttributeValue::B(b) => {
412            use base64::Engine;
413            let encoded = base64::engine::general_purpose::STANDARD.encode(b);
414            format!("B:{encoded}")
415        }
416        AttributeValue::BOOL(b) => format!("BOOL:{b}"),
417        AttributeValue::NULL(_) => "NULL:true".to_string(),
418        AttributeValue::SS(set) => format!("SS:{:?}", set),
419        AttributeValue::NS(set) => format!("NS:{:?}", set),
420        AttributeValue::BS(_) => "BS:[...]".to_string(),
421        AttributeValue::L(_) => "L:[...]".to_string(),
422        AttributeValue::M(_) => "M:{...}".to_string(),
423    }
424}
425
426/// Check if BETWEEN bounds are in valid order (lo <= hi).
427fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
428    match (lo, hi) {
429        (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
430        (AttributeValue::N(a), AttributeValue::N(b)) => {
431            let a_f = a.parse::<f64>().unwrap_or(0.0);
432            let b_f = b.parse::<f64>().unwrap_or(0.0);
433            a_f <= b_f
434        }
435        (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
436        _ => true,
437    }
438}
439
440fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
441    match stream.next() {
442        Some(Token::ValueRef(name)) => Ok(name.clone()),
443        Some(t) => Err(format!("Expected value reference (:name), got {t}")),
444        None => Err("Expected value reference, got end of expression".to_string()),
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// Parse `pk = :pk AND sk BETWEEN :lo AND :hi` and resolve it with the
    /// given bounds (`:pk` is always the string `p`).
    fn resolve_between_bounds(
        lo: AttributeValue,
        hi: AttributeValue,
    ) -> Result<ResolvedKeyCondition, String> {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("p".into())),
            (":lo".to_string(), lo),
            (":hi".to_string(), hi),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        resolve_values(&kc, &resolve_tracker)
    }

    #[test]
    fn test_pk_only() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(kc.sk_condition.is_none());
    }

    #[test]
    fn test_pk_and_sk_eq() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk = :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));
    }

    #[test]
    fn test_pk_and_sk_between() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk BETWEEN :lo AND :hi", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    #[test]
    fn test_pk_and_begins_with() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND begins_with(sk, :prefix)", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    #[test]
    fn test_with_attribute_names() {
        let an = Some(HashMap::from([
            ("#pk".to_string(), "partitionKey".to_string()),
            ("#sk".to_string(), "sortKey".to_string()),
        ]));
        let no_values = None;
        let tracker = make_tracker(&an, &no_values);
        let kc = parse("#pk = :pk AND #sk > :sk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "partitionKey");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey"));
    }

    #[test]
    fn test_resolve_values() {
        let no_names = None;
        let no_values = None;
        let parse_tracker = make_tracker(&no_names, &no_values);
        let kc = parse("pk = :pk AND sk >= :sk", &parse_tracker).unwrap();
        let av = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let resolve_tracker = make_tracker(&no_names, &av);
        let resolved = resolve_values(&kc, &resolve_tracker).unwrap();
        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    #[test]
    fn test_parenthesized_conditions() {
        let no_names = None;
        let no_values = None;

        // Parens around each condition
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk) AND (sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Parens around entire expression
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk AND sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Genuinely nested (non-redundant) parens are accepted.
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk AND (sk > :sk))", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(_, _))));

        // Redundant parentheses are rejected, matching real DynamoDB. Dynoxide
        // previously stripped them and silently accepted.
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("((pk = :pk)) AND ((sk > :sk))", &tracker).unwrap_err();
        assert!(
            err.contains("redundant parentheses"),
            "expected redundant-parentheses rejection, got: {err}"
        );

        // Parens around begins_with
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("(pk = :pk) AND (begins_with(sk, :prefix))", &tracker).unwrap();
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Parens with attribute name references
        let an = Some(HashMap::from([
            ("#pk".to_string(), "PK".to_string()),
            ("#sk".to_string(), "SK".to_string()),
        ]));
        let tracker = make_tracker(&an, &no_values);
        let kc = parse("(#pk = :pk) AND (#sk = :sk)", &tracker).unwrap();
        assert_eq!(kc.pk_name, "PK");
    }

    #[test]
    fn test_sk_comparisons() {
        let no_names = None;
        let no_values = None;
        for (op, variant) in [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")] {
            let tracker = make_tracker(&no_names, &no_values);
            let kc = parse(&format!("pk = :pk AND sk {op} :sk"), &tracker).unwrap();
            let sk = kc.sk_condition.unwrap();
            let name = match &sk {
                SortKeyCondition::Lt(n, _) => format!("Lt:{n}"),
                SortKeyCondition::Le(n, _) => format!("Le:{n}"),
                SortKeyCondition::Gt(n, _) => format!("Gt:{n}"),
                SortKeyCondition::Ge(n, _) => format!("Ge:{n}"),
                _ => "other".to_string(),
            };
            assert!(name.starts_with(variant), "Expected {variant}, got {name}");
        }
    }

    #[test]
    fn test_pk_must_use_equality() {
        let no_names = None;
        let no_values = None;

        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk < :pk", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "got: {err}");

        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk > :a AND sk < :b", &tracker).unwrap_err();
        assert!(err.contains("partition key must use equality"), "got: {err}");
    }

    #[test]
    fn test_sort_key_condition_may_come_first() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let kc = parse("begins_with(sk, :prefix) AND pk = :pk", &tracker).unwrap();
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(ref n, _)) if n == "sk"
        ));
    }

    #[test]
    fn test_trailing_tokens_rejected() {
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        let err = parse("pk = :pk AND sk = :sk AND extra = :x", &tracker).unwrap_err();
        assert!(
            err.starts_with("Unexpected token in KeyConditionExpression"),
            "got: {err}"
        );
    }

    #[test]
    fn test_between_rejects_mismatched_types() {
        let err = resolve_between_bounds(
            AttributeValue::S("a".into()),
            AttributeValue::N("1".into()),
        )
        .unwrap_err();
        assert!(err.contains("requires same data type"), "got: {err}");
        assert!(err.contains("AttributeValue: {S:a}"), "got: {err}");
        assert!(err.contains("AttributeValue: {N:1}"), "got: {err}");
    }

    #[test]
    fn test_between_rejects_reversed_bounds() {
        let err = resolve_between_bounds(
            AttributeValue::S("b".into()),
            AttributeValue::S("a".into()),
        )
        .unwrap_err();
        assert!(
            err.contains("greater than or equal to lower bound"),
            "got: {err}"
        );

        let err = resolve_between_bounds(
            AttributeValue::N("10".into()),
            AttributeValue::N("9".into()),
        )
        .unwrap_err();
        assert!(
            err.contains("greater than or equal to lower bound"),
            "got: {err}"
        );
    }

    #[test]
    fn test_between_accepts_ordered_bounds() {
        assert!(resolve_between_bounds(
            AttributeValue::S("a".into()),
            AttributeValue::S("a".into()),
        )
        .is_ok());
        // Numeric, not lexicographic, ordering: 9 < 10.
        assert!(resolve_between_bounds(
            AttributeValue::N("9".into()),
            AttributeValue::N("10".into()),
        )
        .is_ok());
    }

    #[test]
    fn test_between_sql_operators() {
        let cond = ResolvedSortKeyCondition::Between(
            "sk".into(),
            AttributeValue::S("a".into()),
            AttributeValue::S("b".into()),
        );
        assert_eq!(cond.sk_name(), "sk");
        let ops: Vec<String> = cond
            .to_sql_conditions()
            .into_iter()
            .map(|(op, _)| op)
            .collect();
        assert_eq!(ops, vec![">=".to_string(), "<=".to_string()]);
    }
}