Skip to main content

dynoxide/expressions/
key_condition.rs

1//! KeyConditionExpression parsing.
2//!
3//! KeyConditionExpression supports: `pk = :val [AND sk_condition]`
4//! Sort key conditions: `=`, `<`, `<=`, `>`, `>=`, `BETWEEN ... AND ...`, `begins_with(sk, :prefix)`
5
6use crate::expressions::condition::parse_raw_path;
7use crate::expressions::tokenizer::{Token, TokenStream, tokenize};
8use crate::expressions::{PathElement, TrackedExpressionAttributes};
9use crate::types::AttributeValue;
10
11/// Parsed key condition.
12#[derive(Debug)]
13pub struct KeyCondition {
14    /// Partition key attribute name (resolved).
15    pub pk_name: String,
16    /// Partition key value reference (e.g. `:pk`).
17    pub pk_value_ref: String,
18    /// Optional sort key condition.
19    pub sk_condition: Option<SortKeyCondition>,
20}
21
/// Sort key condition variants.
///
/// The first field of every variant is the resolved sort key attribute name;
/// the remaining fields are value references (e.g. `:sk`) that have not been
/// resolved yet.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can copy and
/// compare parsed conditions directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SortKeyCondition {
    /// `sk = :val` — `(sk_name, value_ref)`.
    Eq(String, String),
    /// `sk < :val` — `(sk_name, value_ref)`.
    Lt(String, String),
    /// `sk <= :val` — `(sk_name, value_ref)`.
    Le(String, String),
    /// `sk > :val` — `(sk_name, value_ref)`.
    Gt(String, String),
    /// `sk >= :val` — `(sk_name, value_ref)`.
    Ge(String, String),
    /// `sk BETWEEN :lo AND :hi` — `(sk_name, lo_ref, hi_ref)`.
    Between(String, String, String),
    /// `begins_with(sk, :prefix)` — `(sk_name, prefix_ref)`.
    BeginsWith(String, String),
}
33
34/// Parse a KeyConditionExpression string, tracking attribute name usage.
35pub fn parse(expr: &str, tracker: &TrackedExpressionAttributes) -> Result<KeyCondition, String> {
36    let tokens = tokenize(expr).map_err(|e| format!("Invalid KeyConditionExpression: {e}"))?;
37    let mut stream = TokenStream::new(tokens);
38
39    let cond1 = parse_single_condition(&mut stream, tracker)?;
40
41    let (pk_cond, sk_cond) = if matches!(stream.peek(), Some(Token::And)) {
42        stream.next();
43        let cond2 = parse_single_condition(&mut stream, tracker)?;
44        match (cond1, cond2) {
45            (ParsedCond::Eq(n1, v1), c2) => ((n1, v1), Some(c2)),
46            (c1, ParsedCond::Eq(n2, v2)) => ((n2, v2), Some(c1)),
47            _ => {
48                return Err(
49                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
50                );
51            }
52        }
53    } else {
54        match cond1 {
55            ParsedCond::Eq(name, val_ref) => ((name, val_ref), None),
56            _ => {
57                return Err(
58                    "Invalid KeyConditionExpression: partition key must use equality".to_string(),
59                );
60            }
61        }
62    };
63
64    if !stream.at_end() {
65        return Err(format!(
66            "Unexpected token in KeyConditionExpression: {}",
67            stream.peek().unwrap()
68        ));
69    }
70
71    let (pk_name, pk_value_ref) = pk_cond;
72    let sk_condition = sk_cond.map(|c| c.into_sk_condition()).transpose()?;
73
74    Ok(KeyCondition {
75        pk_name,
76        pk_value_ref,
77        sk_condition,
78    })
79}
80
81/// Resolve the actual attribute values from the parsed key condition, tracking usage.
82pub fn resolve_values(
83    condition: &KeyCondition,
84    tracker: &TrackedExpressionAttributes,
85) -> Result<ResolvedKeyCondition, String> {
86    let pk_val = tracker.resolve_value(&condition.pk_value_ref)?.clone();
87
88    let sk = if let Some(ref sk_cond) = condition.sk_condition {
89        Some(resolve_sk_condition(sk_cond, tracker)?)
90    } else {
91        None
92    };
93
94    Ok(ResolvedKeyCondition {
95        pk_name: condition.pk_name.clone(),
96        pk_value: pk_val,
97        sk_condition: sk,
98    })
99}
100
101/// Resolved key condition with actual values.
102#[derive(Debug)]
103pub struct ResolvedKeyCondition {
104    pub pk_name: String,
105    pub pk_value: AttributeValue,
106    pub sk_condition: Option<ResolvedSortKeyCondition>,
107}
108
109#[derive(Debug)]
110pub enum ResolvedSortKeyCondition {
111    Eq(String, AttributeValue),
112    Lt(String, AttributeValue),
113    Le(String, AttributeValue),
114    Gt(String, AttributeValue),
115    Ge(String, AttributeValue),
116    Between(String, AttributeValue, AttributeValue),
117    BeginsWith(String, AttributeValue),
118}
119
120impl ResolvedSortKeyCondition {
121    pub fn sk_name(&self) -> &str {
122        match self {
123            Self::Eq(n, _)
124            | Self::Lt(n, _)
125            | Self::Le(n, _)
126            | Self::Gt(n, _)
127            | Self::Ge(n, _)
128            | Self::Between(n, _, _)
129            | Self::BeginsWith(n, _) => n,
130        }
131    }
132
133    /// Convert to SQL WHERE clause components for sk column.
134    /// Returns (operator, value_string) pairs.
135    /// For BETWEEN, returns two conditions.
136    pub fn to_sql_conditions(&self) -> Vec<(String, String)> {
137        match self {
138            Self::Eq(_, v) => vec![("=".into(), val_to_key_string(v))],
139            Self::Lt(_, v) => vec![("<".into(), val_to_key_string(v))],
140            Self::Le(_, v) => vec![("<=".into(), val_to_key_string(v))],
141            Self::Gt(_, v) => vec![(">".into(), val_to_key_string(v))],
142            Self::Ge(_, v) => vec![(">=".into(), val_to_key_string(v))],
143            Self::Between(_, lo, hi) => vec![
144                (">=".into(), val_to_key_string(lo)),
145                ("<=".into(), val_to_key_string(hi)),
146            ],
147            Self::BeginsWith(_, prefix) => {
148                let prefix_str = val_to_key_string(prefix);
149                // Escape LIKE wildcards in the prefix value before appending %
150                let escaped = prefix_str
151                    .replace('\\', "\\\\")
152                    .replace('%', "\\%")
153                    .replace('_', "\\_");
154                vec![("LIKE".into(), format!("{escaped}%"))]
155            }
156        }
157    }
158}
159
160fn val_to_key_string(val: &AttributeValue) -> String {
161    val.to_key_string().unwrap_or_default()
162}
163
164fn resolve_sk_condition(
165    cond: &SortKeyCondition,
166    tracker: &TrackedExpressionAttributes,
167) -> Result<ResolvedSortKeyCondition, String> {
168    match cond {
169        SortKeyCondition::Eq(sk, vr) => {
170            let v = tracker.resolve_value(vr)?.clone();
171            Ok(ResolvedSortKeyCondition::Eq(sk.clone(), v))
172        }
173        SortKeyCondition::Lt(sk, vr) => {
174            let v = tracker.resolve_value(vr)?.clone();
175            Ok(ResolvedSortKeyCondition::Lt(sk.clone(), v))
176        }
177        SortKeyCondition::Le(sk, vr) => {
178            let v = tracker.resolve_value(vr)?.clone();
179            Ok(ResolvedSortKeyCondition::Le(sk.clone(), v))
180        }
181        SortKeyCondition::Gt(sk, vr) => {
182            let v = tracker.resolve_value(vr)?.clone();
183            Ok(ResolvedSortKeyCondition::Gt(sk.clone(), v))
184        }
185        SortKeyCondition::Ge(sk, vr) => {
186            let v = tracker.resolve_value(vr)?.clone();
187            Ok(ResolvedSortKeyCondition::Ge(sk.clone(), v))
188        }
189        SortKeyCondition::Between(sk, lo_ref, hi_ref) => {
190            let lo = tracker.resolve_value(lo_ref)?.clone();
191            let hi = tracker.resolve_value(hi_ref)?.clone();
192            // Validate same type
193            if std::mem::discriminant(&lo) != std::mem::discriminant(&hi) {
194                return Err(format!(
195                    "Invalid KeyConditionExpression: The BETWEEN operator requires same data type \
196                     for lower and upper bounds; lower bound operand: AttributeValue: {{{}}}, \
197                     upper bound operand: AttributeValue: {{{}}}",
198                    format_attr_value_short(&lo),
199                    format_attr_value_short(&hi)
200                ));
201            }
202            // Validate ordering (upper >= lower)
203            if !between_order_valid(&lo, &hi) {
204                return Err(format!(
205                    "Invalid KeyConditionExpression: The BETWEEN operator requires upper bound \
206                     to be greater than or equal to lower bound; lower bound operand: \
207                     AttributeValue: {{{}}}, upper bound operand: AttributeValue: {{{}}}",
208                    format_attr_value_short(&lo),
209                    format_attr_value_short(&hi)
210                ));
211            }
212            Ok(ResolvedSortKeyCondition::Between(sk.clone(), lo, hi))
213        }
214        SortKeyCondition::BeginsWith(sk, vr) => {
215            let v = tracker.resolve_value(vr)?.clone();
216            Ok(ResolvedSortKeyCondition::BeginsWith(sk.clone(), v))
217        }
218    }
219}
220
221// ---------------------------------------------------------------------------
222// Internal parsing helpers
223// ---------------------------------------------------------------------------
224
/// One condition as it appears in the expression, before it is known whether
/// it targets the partition key or the sort key.
///
/// The first field is always the resolved attribute name; the remaining fields
/// are unresolved value references.
#[derive(Debug)]
enum ParsedCond {
    /// `attr = :val` — `(attr_name, value_ref)`.
    Eq(String, String),
    /// `attr < :val` — `(attr_name, value_ref)`.
    Lt(String, String),
    /// `attr <= :val` — `(attr_name, value_ref)`.
    Le(String, String),
    /// `attr > :val` — `(attr_name, value_ref)`.
    Gt(String, String),
    /// `attr >= :val` — `(attr_name, value_ref)`.
    Ge(String, String),
    /// `attr BETWEEN :lo AND :hi` — `(attr_name, lo_ref, hi_ref)`.
    Between(String, String, String),
    /// `begins_with(attr, :prefix)` — `(attr_name, prefix_ref)`.
    BeginsWith(String, String),
}
235
236impl ParsedCond {
237    fn into_sk_condition(self) -> Result<SortKeyCondition, String> {
238        match self {
239            ParsedCond::Eq(n, v) => Ok(SortKeyCondition::Eq(n, v)),
240            ParsedCond::Lt(n, v) => Ok(SortKeyCondition::Lt(n, v)),
241            ParsedCond::Le(n, v) => Ok(SortKeyCondition::Le(n, v)),
242            ParsedCond::Gt(n, v) => Ok(SortKeyCondition::Gt(n, v)),
243            ParsedCond::Ge(n, v) => Ok(SortKeyCondition::Ge(n, v)),
244            ParsedCond::Between(n, lo, hi) => Ok(SortKeyCondition::Between(n, lo, hi)),
245            ParsedCond::BeginsWith(n, v) => Ok(SortKeyCondition::BeginsWith(n, v)),
246        }
247    }
248}
249
250fn parse_single_condition(
251    stream: &mut TokenStream,
252    tracker: &TrackedExpressionAttributes,
253) -> Result<ParsedCond, String> {
254    // Check for begins_with function
255    if let Some(Token::Identifier(name)) = stream.peek() {
256        if name.to_lowercase() == "begins_with" {
257            stream.next();
258            stream.expect(&Token::LParen)?;
259            let path = parse_raw_path(stream)?;
260            let attr_name = resolve_path_to_name(&path, tracker)?;
261            stream.expect(&Token::Comma)?;
262            let val_ref = expect_value_ref(stream)?;
263            stream.expect(&Token::RParen)?;
264            return Ok(ParsedCond::BeginsWith(attr_name, val_ref));
265        }
266    }
267
268    // attr op :val
269    let path = parse_raw_path(stream)?;
270    let attr_name = resolve_path_to_name(&path, tracker)?;
271
272    match stream.next() {
273        Some(Token::Eq) => {
274            let val_ref = expect_value_ref(stream)?;
275            Ok(ParsedCond::Eq(attr_name, val_ref))
276        }
277        Some(Token::Lt) => {
278            let val_ref = expect_value_ref(stream)?;
279            Ok(ParsedCond::Lt(attr_name, val_ref))
280        }
281        Some(Token::Le) => {
282            let val_ref = expect_value_ref(stream)?;
283            Ok(ParsedCond::Le(attr_name, val_ref))
284        }
285        Some(Token::Gt) => {
286            let val_ref = expect_value_ref(stream)?;
287            Ok(ParsedCond::Gt(attr_name, val_ref))
288        }
289        Some(Token::Ge) => {
290            let val_ref = expect_value_ref(stream)?;
291            Ok(ParsedCond::Ge(attr_name, val_ref))
292        }
293        Some(Token::Between) => {
294            let lo_ref = expect_value_ref(stream)?;
295            stream.expect(&Token::And)?;
296            let hi_ref = expect_value_ref(stream)?;
297            Ok(ParsedCond::Between(attr_name, lo_ref, hi_ref))
298        }
299        Some(t) => Err(format!(
300            "Unexpected operator in KeyConditionExpression: {t}"
301        )),
302        None => Err("Unexpected end of KeyConditionExpression".to_string()),
303    }
304}
305
306fn resolve_path_to_name(
307    path: &[PathElement],
308    tracker: &TrackedExpressionAttributes,
309) -> Result<String, String> {
310    if path.len() != 1 {
311        return Err("KeyConditionExpression only supports top-level attributes".to_string());
312    }
313    match &path[0] {
314        PathElement::Attribute(name) => {
315            if name.starts_with('#') {
316                tracker.resolve_name(name)
317            } else {
318                Ok(name.clone())
319            }
320        }
321        PathElement::Index(_) => Err("KeyConditionExpression cannot use index paths".to_string()),
322    }
323}
324
325/// Format an attribute value for error messages (DynamoDB short format).
326fn format_attr_value_short(val: &AttributeValue) -> String {
327    match val {
328        AttributeValue::S(s) => format!("S:{s}"),
329        AttributeValue::N(n) => format!("N:{n}"),
330        AttributeValue::B(b) => {
331            use base64::Engine;
332            let encoded = base64::engine::general_purpose::STANDARD.encode(b);
333            format!("B:{encoded}")
334        }
335        AttributeValue::BOOL(b) => format!("BOOL:{b}"),
336        AttributeValue::NULL(_) => "NULL:true".to_string(),
337        AttributeValue::SS(set) => format!("SS:{:?}", set),
338        AttributeValue::NS(set) => format!("NS:{:?}", set),
339        AttributeValue::BS(_) => "BS:[...]".to_string(),
340        AttributeValue::L(_) => "L:[...]".to_string(),
341        AttributeValue::M(_) => "M:{...}".to_string(),
342    }
343}
344
345/// Check if BETWEEN bounds are in valid order (lo <= hi).
346fn between_order_valid(lo: &AttributeValue, hi: &AttributeValue) -> bool {
347    match (lo, hi) {
348        (AttributeValue::S(a), AttributeValue::S(b)) => a <= b,
349        (AttributeValue::N(a), AttributeValue::N(b)) => {
350            let a_f = a.parse::<f64>().unwrap_or(0.0);
351            let b_f = b.parse::<f64>().unwrap_or(0.0);
352            a_f <= b_f
353        }
354        (AttributeValue::B(a), AttributeValue::B(b)) => a <= b,
355        _ => true,
356    }
357}
358
359fn expect_value_ref(stream: &mut TokenStream) -> Result<String, String> {
360    match stream.next() {
361        Some(Token::ValueRef(name)) => Ok(name.clone()),
362        Some(t) => Err(format!("Expected value reference (:name), got {t}")),
363        None => Err("Expected value reference, got end of expression".to_string()),
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a tracker over the given placeholder maps.
    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// Parse `expr` with no attribute names or values supplied, panicking on
    /// a parse error.
    fn parse_bare(expr: &str) -> KeyCondition {
        let names = None;
        let values = None;
        let tracker = make_tracker(&names, &values);
        parse(expr, &tracker).unwrap()
    }

    #[test]
    fn test_pk_only() {
        let kc = parse_bare("pk = :pk");
        assert_eq!(kc.pk_name, "pk");
        assert_eq!(kc.pk_value_ref, ":pk");
        assert!(kc.sk_condition.is_none());
    }

    #[test]
    fn test_pk_and_sk_eq() {
        let kc = parse_bare("pk = :pk AND sk = :sk");
        assert_eq!(kc.pk_name, "pk");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Eq(_, _))));
    }

    #[test]
    fn test_pk_and_sk_between() {
        let kc = parse_bare("pk = :pk AND sk BETWEEN :lo AND :hi");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::Between(_, _, _))
        ));
    }

    #[test]
    fn test_pk_and_begins_with() {
        let kc = parse_bare("pk = :pk AND begins_with(sk, :prefix)");
        assert!(matches!(
            kc.sk_condition,
            Some(SortKeyCondition::BeginsWith(_, _))
        ));
    }

    #[test]
    fn test_with_attribute_names() {
        let names = Some(HashMap::from([
            ("#pk".to_string(), "partitionKey".to_string()),
            ("#sk".to_string(), "sortKey".to_string()),
        ]));
        let values = None;
        let tracker = make_tracker(&names, &values);

        let kc = parse("#pk = :pk AND #sk > :sk", &tracker).unwrap();

        assert_eq!(kc.pk_name, "partitionKey");
        assert!(matches!(kc.sk_condition, Some(SortKeyCondition::Gt(ref n, _)) if n == "sortKey"));
    }

    #[test]
    fn test_resolve_values() {
        // Parsing needs no values; resolving uses a second tracker that has them.
        let kc = parse_bare("pk = :pk AND sk >= :sk");

        let names = None;
        let values = Some(HashMap::from([
            (":pk".to_string(), AttributeValue::S("user#1".into())),
            (":sk".to_string(), AttributeValue::S("2024-01-01".into())),
        ]));
        let resolve_tracker = make_tracker(&names, &values);
        let resolved = resolve_values(&kc, &resolve_tracker).unwrap();

        assert_eq!(resolved.pk_value, AttributeValue::S("user#1".into()));
        assert!(matches!(
            resolved.sk_condition,
            Some(ResolvedSortKeyCondition::Ge(_, _))
        ));
    }

    #[test]
    fn test_sk_comparisons() {
        let cases = [("<", "Lt"), ("<=", "Le"), (">", "Gt"), (">=", "Ge")];
        for (op, variant) in cases {
            let expr = format!("pk = :pk AND sk {op} :sk");
            let sk = parse_bare(&expr).sk_condition.unwrap();
            let name = match &sk {
                SortKeyCondition::Lt(n, _) => format!("Lt:{n}"),
                SortKeyCondition::Le(n, _) => format!("Le:{n}"),
                SortKeyCondition::Gt(n, _) => format!("Gt:{n}"),
                SortKeyCondition::Ge(n, _) => format!("Ge:{n}"),
                _ => "other".to_string(),
            };
            assert!(name.starts_with(variant), "Expected {variant}, got {name}");
        }
    }
}