Skip to main content

dynoxide/expressions/
key_condition.rs

1//! KeyConditionExpression parsing.
2//!
3//! KeyConditionExpression supports: `pk = :val [AND sk_condition]`
4//! Sort key conditions: `=`, `<`, `<=`, `>`, `>=`, `BETWEEN ... AND ...`, `begins_with(sk, :prefix)`
5
6use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{Token, TokenSpan, TokenStream, tokenize};
8use crate::expressions::{PathElement, TrackedExpressionAttributes};
9use crate::types::AttributeValue;
10
/// Parsed key condition.
///
/// Attribute names are already resolved (any `#name` placeholder has been
/// replaced via the expression attribute names), while values are still
/// `:placeholder` references; see [`resolve_values`] to look them up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyCondition {
    /// Partition key attribute name (resolved).
    pub pk_name: String,
    /// Partition key value reference (e.g. `:pk`).
    pub pk_value_ref: String,
    /// Optional sort key condition.
    pub sk_condition: Option<SortKeyCondition>,
}

/// Sort key condition variants.
///
/// The first field of every variant is the resolved sort-key attribute name;
/// the remaining fields are `:placeholder` value references.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortKeyCondition {
    /// `sk = :v` — (sk_name, value_ref)
    Eq(String, String),
    /// `sk < :v` — (sk_name, value_ref)
    Lt(String, String),
    /// `sk <= :v` — (sk_name, value_ref)
    Le(String, String),
    /// `sk > :v` — (sk_name, value_ref)
    Gt(String, String),
    /// `sk >= :v` — (sk_name, value_ref)
    Ge(String, String),
    /// `sk BETWEEN :lo AND :hi` — (sk_name, lo_ref, hi_ref)
    Between(String, String, String),
    /// `begins_with(sk, :prefix)` — (sk_name, prefix_ref)
    BeginsWith(String, String),
}
33
34/// Parse a KeyConditionExpression string, tracking attribute name usage.
35///
36/// Supports optional parentheses around individual conditions and around
37/// the entire expression, matching DynamoDB behavior.
38pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
39    let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
40    let tokens = strip_outer_parens(tokens);
41    let mut stream = TokenStream::new(tokens);
42
43    let cond1 = parse_single_condition(&mut stream, tracker)?;
44
45    let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
46        stream.next();
47        let cond2 = parse_single_condition(&mut stream, tracker)?;
48        match (cond1, cond2) {
49            (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
50            (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
51            _ => {
52                return Err(
53                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
54                );
55            }
56        }
57    } else {
58        match cond1 {
59            ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
60            _ => {
61                return Err(
62                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
63                );
64            }
65        }
66    };
67
68    if !stream.at_end() {
69        return Err(format!(
70            "Unexpected token in KeyConditionExpression: {}",
71            stream.peek().unwrap()
72        ));
73    }
74
75    let (pk_name, pk_value_ref) = pk_cond;
76    let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
77
78    Ok(KeyCondition {
79        pk_name,
80        pk_value_ref,
81        sk_condition,
82    })
83}
84
85/// Resolve the actual attribute values from the parsed key condition, tracking usage.
86pub fn resolve_values(
87    condition: &KeyCondition,
88    tracker: &TrackedExpressionAttributes,
89) -> Result<ResolvedKeyCondition, String> {
90    let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
91
92    let sk = if let Some(ref sk_cond) = condition.sk_condition {
93        Some(resolve_sk_condition(sk_cond, tracker)?)
94    } else {
95        None
96    };
97
98    Ok(ResolvedKeyCondition {
99        pk_name: condition.pk_name.clone(),
100        pk_value: pk_val,
101        sk_condition: sk,
102    })
103}
104
105/// Resolved key condition with actual values.
106#[derive(Debug)]
107pub struct ResolvedKeyCondition {
108    pub pk_name: String,
109    pub pk_value: AttributeValue,
110    pub sk_condition: Option<ResolvedSortKeyCondition>,
111}
112
113#[derive(Debug)]
114pub enum ResolvedSortKeyCondition {
115    Eq(String, AttributeValue),
116    Lt(String, AttributeValue),
117    Le(String, AttributeValue),
118    Gt(String, AttributeValue),
119    Ge(String, AttributeValue),
120    Between(String, AttributeValue, AttributeValue),
121    BeginsWith(String, AttributeValue),
122}
123
124impl ResolvedSortKeyCondition {
125    pub fn sk_name(&self) -> &str {
126        match self {
127            Self::Eq(n, _)
128            | Self::Lt(n, _)
129            | Self::Le(n, _)
130            | Self::Gt(n, _)
131            | Self::Ge(n, _)
132            | Self::Between(n, _, _)
133            | Self::BeginsWith(n, _) => n,
134        }
135    }
136
137    /// Convert to SQL WHERE clause components for sk column.
138    /// Returns (operator, value_string) pairs.
139    /// For BETWEEN, returns two conditions.
140    pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
141        match self {
142            Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
143            Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
144            Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
145            Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
146            Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
147            Self::Between(_, lo, hi) => vec![
148                (">=".into(), val_to_key_string(lo)),
149                ("<=".into(), val_to_key_string(hi)),
150            ],
151            Self::BeginsWith(_, prefix) => {
152                let prefix_str = val_to_key_string(prefix);
153                // Escape LIKE wildcards in the prefix value before appending %
154                let escaped = prefix_str
155                    .replace('\\', "\\\\")
156                    .replace('%', "\\%")
157                    .replace('_', "\\_");
158                vec![("LIKE".into(), format!("{escaped}%"))]
159            }
160        }
161    }
162}
163
164fn val_to_key_string(val: &AttributeValue) -> String {
165    val.to_key_string().unwrap_or_default()
166}
167
168fn resolve_sk_condition(
169    cond: &SortKeyCondition,
170    tracker: &TrackedExpressionAttributes,
171) -> Result<ResolvedSortKeyCondition, String> {
172    match cond {
173        SortKeyCondition::Eq(sk, vr) => {
174            let v = tracker.resolve_value(vr)?.clone();
175            Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
176        }
177        SortKeyCondition::Lt(sk, vr) => {
178            let v = tracker.resolve_value(vr)?.clone();
179            Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
180        }
181        SortKeyCondition::Le(sk, vr) => {
182            let v = tracker.resolve_value(vr)?.clone();
183            Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
184        }
185        SortKeyCondition::Gt(sk, vr) => {
186            let v = tracker.resolve_value(vr)?.clone();
187            Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
188        }
189        SortKeyCondition::Ge(sk, vr) => {
190            let v = tracker.resolve_value(vr)?.clone();
191            Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
192        }
193        SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
194            let lo = tracker.resolve_value(lo_ref)?.clone();
195            let hi = tracker.resolve_value(hi_ref)?.clone();
196            // Validate same type
197            if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
198                return Err(format!(
199                    "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
200                     for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
201                     upper bound operand: AttributeValue: {{{}}}",
202                    format_attr_value_short(&lo),
203                    format_attr_value_short(&hi)
204                ));
205            }
206            // Validate ordering (upper >= lower)
207            if !between_order_valid(&lo, &hi) {
208                return Err(format!(
209                    "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
210                     to be greater than or equal to lower bound; lower bound operand: \
211                     AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
212                    format_attr_value_short(&lo),
213                    format_attr_value_short(&hi)
214                ));
215            }
216            Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
217        }
218        SortKeyCondition::BeginsWith(sk, vr) => {
219            let v = tracker.resolve_value(vr)?.clone();
220            Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
221        }
222    }
223}
224
225// ---------------------------------------------------------------------------
226// Internal parsing helpers
227// ---------------------------------------------------------------------------
228
229#[derive(Debug)]
230enum ParsedCond {
231    Eq(String, String), // (attr_name, value_ref)
232    Lt(String, String),
233    Le(String, String),
234    Gt(String, String),
235    Ge(String, String),
236    Between(String, String, String), // (attr_name, lo_ref, hi_ref)
237    BeginsWith(String, String),      // (attr_name, prefix_ref)
238}
239
240impl ParsedCond {
241    fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
242        match self {
243            ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
244            ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
245            ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
246            ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
247            ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
248            ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
249            ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
250        }
251    }
252}
253
254/// Strip balanced outer parentheses from a token list.
255/// `(pk = :pk AND sk = :sk)` → `pk = :pk AND sk = :sk`
256/// `((pk = :pk))` → `pk = :pk` (applied repeatedly)
257/// `(pk = :pk) AND (sk = :sk)` → unchanged (closing paren is not at the end)
258fn strip_outer_parens(mut tokens: Vec<(Token, TokenSpan)>) -> Vec<(Token, TokenSpan)> {
259    loop {
260        if tokens.len() < 2 {
261            break;
262        }
263        if !matches!(tokens.first().map(|(t, _)| t), Some(Token::LParen)) {
264            break;
265        }
266        // Walk forward, tracking paren depth, to see if the opening paren's
267        // match is the very last token.
268        let mut depth = 0;
269        let mut close_pos = None;
270        for (i, (tok, _)) in tokens.iter().enumerate() {
271            match tok {
272                Token::LParen => depth += 1,
273                Token::RParen => {
274                    depth -= 1;
275                    if depth == 0 {
276                        close_pos = Some(i);
277                        break;
278                    }
279                }
280                _ => {}
281            }
282        }
283        if close_pos == Some(tokens.len() - 1) {
284            // The outermost parens wrap the entire expression — strip them.
285            tokens.remove(tokens.len() - 1);
286            tokens.remove(0);
287        } else {
288            break;
289        }
290    }
291    tokens
292}
293
294fn parse_single_condition(
295    stream: &mut TokenStream,
296    tracker: &TrackedExpressionAttributes,
297) -> Result<ParsedCond, String> {
298    // Count and skip optional wrapping parentheses: `(pk = :pk)`, `((pk = :pk))`
299    let mut parens = 0;
300    while matches!(stream.peek(), Some(Token::LParen)) {
301        stream.next();
302        parens += 1;
303    }
304
305    // Check for begins_with function
306    if let Some(Token::Identifier(name)) = stream.peek() {
307        if name.to_lowercase() == "begins_with" {
308            stream.next();
309            stream.expect(&Token::LParen)?;
310            let path = parse_raw_path(stream)?;
311            let attr_name = resolve_path_to_name(&path, tracker)?;
312            stream.expect(&Token::Comma)?;
313            let val_ref = expect_value_ref(stream)?;
314            stream.expect(&Token::RParen)?;
315            consume_close_parens(stream, parens)?;
316            return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
317        }
318    }
319
320    // attr op :val
321    let path = parse_raw_path(stream)?;
322    let attr_name = resolve_path_to_name(&path, tracker)?;
323
324    let result = match stream.next() {
325        Some(Token::Eq) => {
326            let val_ref = expect_value_ref(stream)?;
327            Ok(ParsedCond::Eq(attr_name, val_ref))
328        }
329        Some(Token::Lt) => {
330            let val_ref = expect_value_ref(stream)?;
331            Ok(ParsedCond::Lt(attr_name, val_ref))
332        }
333        Some(Token::Le) => {
334            let val_ref = expect_value_ref(stream)?;
335            Ok(ParsedCond::Le(attr_name, val_ref))
336        }
337        Some(Token::Gt) => {
338            let val_ref = expect_value_ref(stream)?;
339            Ok(ParsedCond::Gt(attr_name, val_ref))
340        }
341        Some(Token::Ge) => {
342            let val_ref = expect_value_ref(stream)?;
343            Ok(ParsedCond::Ge(attr_name, val_ref))
344        }
345        Some(Token::Between) => {
346            let lo_ref = expect_value_ref(stream)?;
347            stream.expect(&Token::And)?;
348            let hi_ref = expect_value_ref(stream)?;
349            Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
350        }
351        Some(t) => Err(format!(
352            "Unexpected operator in KeyConditionExpression: {t}"
353        )),
354        None => Err("Unexpected end of KeyConditionExpression".to_string()),
355    };
356
357    consume_close_parens(stream, parens)?;
358    result
359}
360
361/// Consume exactly `count` closing parentheses from the stream.
362fn consume_close_parens(stream: &mut TokenStream, count: usize) -> Result<(), String> {
363    for _ in 0..count {
364        match stream.next() {
365            Some(Token::RParen) => {}
366            Some(t) => {
367                return Err(format!(
368                    "Expected closing parenthesis in KeyConditionExpression, got {t}"
369                ));
370            }
371            None => {
372                return Err(
373                    "Unexpected end of KeyConditionExpression, expected closing parenthesis"
374                        .to_string(),
375                );
376            }
377        }
378    }
379    Ok(())
380}
381
382fn resolve_path_to_name(
383    path: &[PathElement],
384    tracker: &TrackedExpressionAttributes,
385) -> Result<String, String> {
386    if path.len() != 1 {
387        return Err("KeyConditionExpression only supports top-level attributes".to_string());
388    }
389    match &path[0] {
390        PathElement::Attribute(name) => {
391            if name.starts_with('#') {
392                tracker.resolve_name(name)
393            } else {
394                Ok(name.clone())
395            }
396        }
397        PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
398    }
399}
400
401/// Format an attribute value for error messages (DynamoDB short format).
402fn format_attr_value_short(val: &AttributeValue) -> String {
403    match val {
404        AttributeValue::S(s) => format!("S:{s}"),
405        AttributeValue::N(n) => format!("N:{n}"),
406        AttributeValue::B(b) => {
407            use base64::Engine;
408            let encoded = base64::engine::general_purpose::STANDARD.encode(b);
409            format!("B:{encoded}")
410        }
411        AttributeValue::BOOL(b) => format!("BOOL:{b}"),
412        AttributeValue::NULL(_) => "NULL:true".to_string(),
413        AttributeValue::SS(set) => format!("SS:{:?}", set),
414        AttributeValue::NS(set) => format!("NS:{:?}", set),
415        AttributeValue::BS(_) => "BS:[...]".to_string(),
416        AttributeValue::L(_) => "L:[...]".to_string(),
417        AttributeValue::M(_) => "M:{...}".to_string(),
418    }
419}
420
421/// Check if BETWEEN bounds are in valid order (lo <= hi).
422fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
423    match (lo, hi) {
424        (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
425        (AttributeValue::N(a), AttributeValue::N(b)) => {
426            let a_f = a.parse::<f64>().unwrap_or(0.0);
427            let b_f = b.parse::<f64>().unwrap_or(0.0);
428            a_f <= b_f
429        }
430        (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
431        _ => true,
432    }
433}
434
435fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
436    match stream.next() {
437        Some(Token::ValueRef(name)) => Ok(name.clone()),
438        Some(t) => Err(format!("Expected value reference (:name), got {t}")),
439        None => Err("Expected value reference, got end of expression".to_string()),
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// Parse `expr` with no attribute names or values, panicking on error.
    fn parse_plain(expr: &str) -> KeyCondition {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);
        parse(expr, &tracker).unwrap()
    }

    /// Parse `expr` with the given `#name` → attribute mapping and no values.
    fn parse_with_names(expr: &str, pairs: &[(&str, &str)]) -> KeyCondition {
        let names = Some(
            pairs
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect::<HashMap<_, _>>(),
        );
        let values = None;
        let tracker = make_tracker(&names, &values);
        parse(expr, &tracker).unwrap()
    }

    #[test]
    fn test_pk_only() {
        let kc = parse_plain("pk = :pk");
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(kc.sk_condition.is_none());
    }

    #[test]
    fn test_pk_and_sk_eq() {
        let kc = parse_plain("pk = :pk AND sk = :sk");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));
    }

    #[test]
    fn test_pk_and_sk_between() {
        let kc = parse_plain("pk = :pk AND sk BETWEEN :lo AND :hi");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    #[test]
    fn test_pk_and_begins_with() {
        let kc = parse_plain("pk = :pk AND begins_with(sk, :prefix)");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    #[test]
    fn test_with_attribute_names() {
        let kc = parse_with_names(
            "#pk = :pk AND #sk > :sk",
            &[("#pk", "partitionKey"), ("#sk", "sortKey")],
        );
        assert_eq!(kc.pk_name, "partitionKey");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey"));
    }

    #[test]
    fn test_resolve_values() {
        let kc = parse_plain("pk = :pk AND sk >= :sk");
        let names = None;
        let values = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let tracker = make_tracker(&names, &values);
        let resolved = resolve_values(&kc, &tracker).unwrap();
        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    #[test]
    fn test_parenthesized_conditions() {
        // Parens around each condition
        let kc = parse_plain("(pk = :pk) AND (sk = :sk)");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Parens around the entire expression
        let kc = parse_plain("(pk = :pk AND sk = :sk)");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));

        // Nested parens
        let kc = parse_plain("((pk = :pk)) AND ((sk > :sk))");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(_, _))));

        // Outer parens wrapping conditions that carry their own nested parens
        let kc = parse_plain("(((pk = :pk)) AND ((begins_with(sk, :prefix))))");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Parens around begins_with
        let kc = parse_plain("(pk = :pk) AND (begins_with(sk, :prefix))");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));

        // Parens with attribute name references
        let kc = parse_with_names(
            "(#pk = :pk) AND (#sk = :sk)",
            &[("#pk", "PK"), ("#sk", "SK")],
        );
        assert_eq!(kc.pk_name, "PK");
    }

    #[test]
    fn test_sk_comparisons() {
        for (op, variant) in [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")] {
            let kc = parse_plain(&format!("pk = :pk AND sk {op} :sk"));
            let actual = match kc.sk_condition.unwrap() {
                SortKeyCondition::Lt(..) => "Lt",
                SortKeyCondition::Le(..) => "Le",
                SortKeyCondition::Gt(..) => "Gt",
                SortKeyCondition::Ge(..) => "Ge",
                _ => "other",
            };
            assert_eq!(actual, variant, "Expected {variant} for operator {op}");
        }
    }
}