Skip to main content

dynoxide/expressions/
update.rs

1//! UpdateExpression parsing and evaluation.
2//!
3//! Supports SET, REMOVE, ADD, DELETE clauses.
4
5use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{
7    Token, TokenStream, near_window_parser, near_window_tokenizer, tokenize,
8};
9use crate::expressions::{
10    PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
11    set_path,
12};
13use crate::types::AttributeValue;
14use std::collections::HashMap;
15
/// Parsed update expression with all clause actions.
///
/// Produced by `parse`. Each field holds the actions of one clause; a clause
/// that did not appear in the expression leaves its vector empty.
#[derive(Debug)]
pub struct UpdateExpr {
    /// Actions of the `SET` clause, in the order they were written.
    pub set_actions: Vec<SetAction>,
    /// Raw (unresolved) document paths listed in the `REMOVE` clause.
    pub remove_actions: Vec<Vec<PathElement>>,
    /// Actions of the `ADD` clause.
    pub add_actions: Vec<AddAction>,
    /// Actions of the `DELETE` clause.
    pub delete_actions: Vec<DeleteAction>,
}
24
/// A SET action: `path = value_expr`
///
/// `path` is stored unresolved; any `#name` placeholders in it are
/// substituted through the expression-attribute tracker when the action is
/// tracked or applied.
#[derive(Debug)]
pub struct SetAction {
    /// Target document path (left-hand side of `=`).
    pub path: Vec<PathElement>,
    /// Right-hand side expression producing the new value.
    pub value: SetValue,
}
31
/// Value expression for SET.
///
/// `parse_set_value` reads a single operand, optionally followed by one `+`
/// or `-` and a second operand. Parenthesised sub-expressions appear as
/// `SetOperand::Group`.
#[derive(Debug)]
pub enum SetValue {
    /// Direct value or path reference
    Operand(SetOperand),
    /// `operand + operand` (both sides must evaluate to numbers)
    Plus(SetOperand, SetOperand),
    /// `operand - operand` (both sides must evaluate to numbers)
    Minus(SetOperand, SetOperand),
}
42
/// An operand in a SET expression.
#[derive(Debug)]
pub enum SetOperand {
    /// A document path whose current value is read from the item being
    /// updated (evaluation fails if nothing exists there).
    Path(Vec<PathElement>),
    /// An expression attribute value placeholder, resolved through the
    /// tracker.
    ValueRef(String),
    /// `if_not_exists(path, default)`: the value at `path` if present,
    /// otherwise the evaluated `default`.
    IfNotExists(Vec<PathElement>, Box<SetOperand>),
    /// `list_append(a, b)`: the elements of list `a` followed by those of
    /// list `b`.
    ListAppend(Box<SetOperand>, Box<SetOperand>),
    /// A parenthesised sub-expression, e.g. `(c - :v)`.
    Group(Box<SetValue>),
}
53
/// An ADD action: `path :value`
#[derive(Debug)]
pub struct AddAction {
    /// Target document path, stored unresolved (may contain `#name`
    /// placeholders).
    pub path: Vec<PathElement>,
    /// Expression attribute value placeholder naming the operand; resolved
    /// through the tracker.
    pub value_ref: String,
}
60
/// A DELETE action: `path :value`
#[derive(Debug)]
pub struct DeleteAction {
    /// Target document path, stored unresolved (may contain `#name`
    /// placeholders).
    pub path: Vec<PathElement>,
    /// Expression attribute value placeholder naming the set of members to
    /// remove; resolved through the tracker.
    pub value_ref: String,
}
67
68/// Parse an UpdateExpression string.
69pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
70    let tokens = match tokenize(expr) {
71        Ok(t) => t,
72        Err(err) => {
73            // Tokenizer-level syntax error (e.g. stray `!` mid-expression):
74            // emit the same shape as parser-level errors, with a tokenizer-style
75            // near: window (offending byte plus at most one more non-whitespace byte).
76            let bad = &expr[err.position..err.position + err.bad_len];
77            let near = near_window_tokenizer(expr, err.position);
78            return Err(format!(
79                r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
80            ));
81        }
82    };
83    let mut stream = TokenStream::new(tokens);
84
85    let mut set_actions = Vec::new();
86    let mut remove_actions = Vec::new();
87    let mut add_actions = Vec::new();
88    let mut delete_actions = Vec::new();
89
90    let mut seen_set = false;
91    let mut seen_remove = false;
92    let mut seen_add = false;
93    let mut seen_delete = false;
94
95    while !stream.at_end() {
96        match stream.peek() {
97            Some(Token::Set) => {
98                if seen_set {
99                    return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
100                }
101                seen_set = true;
102                stream.next();
103                parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
104            }
105            Some(Token::Remove) => {
106                if seen_remove {
107                    return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
108                }
109                seen_remove = true;
110                stream.next();
111                parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
112            }
113            Some(Token::Add) => {
114                if seen_add {
115                    return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
116                }
117                seen_add = true;
118                stream.next();
119                parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
120            }
121            Some(Token::Delete) => {
122                if seen_delete {
123                    return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
124                }
125                seen_delete = true;
126                stream.next();
127                parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
128            }
129            Some(_) => {
130                // Unexpected leading token where SET/REMOVE/ADD/DELETE was required.
131                // Build the AWS-style "token: \"X\", near: \"X Y\"" window from the
132                // offending token's span and the next token's span (if any).
133                let offending_span = stream
134                    .peek_span()
135                    .expect("peek_span must yield when peek did");
136                let bad = &expr[offending_span.start..offending_span.end()];
137                stream.next();
138                let next_span = stream.peek_span();
139                let near = near_window_parser(expr, offending_span, next_span);
140                return Err(format!(
141                    r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
142                ));
143            }
144            None => break,
145        }
146    }
147
148    Ok(UpdateExpr {
149        set_actions,
150        remove_actions,
151        add_actions,
152        delete_actions,
153    })
154}
155
/// Prefix a clause sub-parser error so it reads as an UpdateExpression error.
///
/// * A message that already starts with `Invalid UpdateExpression:` is
///   returned untouched.
/// * A reserved-keyword message gets only the `Invalid UpdateExpression: `
///   prefix.
/// * Anything else is treated as a syntax error and becomes
///   `Invalid UpdateExpression: Syntax error; <err>`.
fn wrap_syntax_error(err: String) -> String {
    const PREFIX: &str = "Invalid UpdateExpression:";

    if err.starts_with(PREFIX) {
        return err;
    }
    let separator = if err.starts_with("Attribute name is a reserved keyword") {
        " "
    } else {
        " Syntax error; "
    };
    format!("{PREFIX}{separator}{err}")
}
167
168/// Walk an UpdateExpr and track all attribute name and value references
169/// without actually evaluating or modifying any item. This is used for
170/// pre-validation: checking that all referenced names/values are defined,
171/// and detecting unused names/values.
172pub fn track_references(
173    expr: &UpdateExpr,
174    tracker: &TrackedExpressionAttributes,
175) -> Result<(), String> {
176    // Collect all target paths for overlap/conflict detection
177    let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
178
179    for action in &expr.set_actions {
180        track_path_refs(&action.path, tracker)?;
181        track_set_value_refs(&action.value, tracker)?;
182        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
183    }
184    for path in &expr.remove_actions {
185        track_path_refs(path, tracker)?;
186        all_target_paths.push(resolve_tracked_path(path, tracker));
187    }
188    for action in &expr.add_actions {
189        track_path_refs(&action.path, tracker)?;
190        let val = tracker.resolve_value(&action.value_ref)?;
191        // Validate ADD operand type statically
192        validate_add_type(val)?;
193        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
194    }
195    for action in &expr.delete_actions {
196        track_path_refs(&action.path, tracker)?;
197        let val = tracker.resolve_value(&action.value_ref)?;
198        // Validate DELETE operand type statically
199        validate_delete_type(val)?;
200        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
201    }
202
203    // Static type validation for SET value expressions
204    for action in &expr.set_actions {
205        validate_set_value_types(&action.value, tracker)?;
206    }
207
208    // Check for overlapping/conflicting paths
209    check_path_overlaps(&all_target_paths)?;
210
211    Ok(())
212}
213
214/// Validate that an ADD operand has a compatible type.
215fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
216    use crate::types::AttributeValue;
217    match val {
218        AttributeValue::N(_)
219        | AttributeValue::SS(_)
220        | AttributeValue::NS(_)
221        | AttributeValue::BS(_) => Ok(()),
222        _ => Err(format!(
223            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
224             operator: ADD, operand type: {}",
225            dynamo_type_name(val)
226        )),
227    }
228}
229
230/// Validate that a DELETE operand has a compatible type.
231fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
232    use crate::types::AttributeValue;
233    match val {
234        AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
235        _ => Err(format!(
236            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
237             operator: DELETE, operand type: {}",
238            dynamo_type_name(val)
239        )),
240    }
241}
242
243/// Map an AttributeValue to its DynamoDB type name for error messages.
244fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
245    use crate::types::AttributeValue;
246    match val {
247        AttributeValue::S(_) => "STRING",
248        AttributeValue::N(_) => "NUMBER",
249        AttributeValue::B(_) => "BINARY",
250        AttributeValue::BOOL(_) => "BOOLEAN",
251        AttributeValue::NULL(_) => "NULL",
252        AttributeValue::SS(_) => "SS",
253        AttributeValue::NS(_) => "NS",
254        AttributeValue::BS(_) => "BS",
255        AttributeValue::L(_) => "LIST",
256        AttributeValue::M(_) => "MAP",
257    }
258}
259
260/// Validate types for SET value expressions (arithmetic, list_append).
261fn validate_set_value_types(
262    value: &SetValue,
263    tracker: &TrackedExpressionAttributes,
264) -> Result<(), String> {
265    match value {
266        SetValue::Operand(op) => validate_set_operand_types(op, tracker),
267        SetValue::Plus(left, right) => {
268            validate_arithmetic_operand(left, "+", tracker)?;
269            validate_arithmetic_operand(right, "+", tracker)
270        }
271        SetValue::Minus(left, right) => {
272            validate_arithmetic_operand(left, "-", tracker)?;
273            validate_arithmetic_operand(right, "-", tracker)
274        }
275    }
276}
277
278/// Validate that an operand used in + or - is a number (if it's a value ref).
279fn validate_arithmetic_operand(
280    operand: &SetOperand,
281    op: &str,
282    tracker: &TrackedExpressionAttributes,
283) -> Result<(), String> {
284    use crate::types::AttributeValue;
285    match operand {
286        SetOperand::ValueRef(name) => {
287            let val = tracker.resolve_value(name)?;
288            if !matches!(val, AttributeValue::N(_)) {
289                return Err(format!(
290                    "Invalid UpdateExpression: Incorrect operand type for operator or function; \
291                     operator or function: {op}, operand type: {}",
292                    dynamo_type_name(val)
293                ));
294            }
295            Ok(())
296        }
297        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
298        SetOperand::ListAppend(a, b) => {
299            validate_list_append_operand(a, tracker)?;
300            validate_list_append_operand(b, tracker)
301        }
302        SetOperand::Path(_) => Ok(()), // Path types checked at runtime
303        // A parenthesised group resolves to a number at runtime; validate its
304        // inner expression but leave the numeric check to evaluation.
305        SetOperand::Group(inner) => validate_set_value_types(inner, tracker),
306    }
307}
308
309/// Validate types for a set operand (recursively).
310fn validate_set_operand_types(
311    operand: &SetOperand,
312    tracker: &TrackedExpressionAttributes,
313) -> Result<(), String> {
314    match operand {
315        SetOperand::ListAppend(a, b) => {
316            validate_list_append_operand(a, tracker)?;
317            validate_list_append_operand(b, tracker)
318        }
319        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
320        SetOperand::Group(inner) => validate_set_value_types(inner, tracker),
321        _ => Ok(()),
322    }
323}
324
325/// Validate a list_append operand is a list if it's a value ref.
326fn validate_list_append_operand(
327    operand: &SetOperand,
328    tracker: &TrackedExpressionAttributes,
329) -> Result<(), String> {
330    use crate::types::AttributeValue;
331    if let SetOperand::ValueRef(name) = operand {
332        let val = tracker.resolve_value(name)?;
333        if !matches!(val, AttributeValue::L(_)) {
334            return Err(format!(
335                "Invalid UpdateExpression: Incorrect operand type for operator or function; \
336                 operator or function: list_append, operand type: {}",
337                dynamo_type_name(val)
338            ));
339        }
340    }
341    Ok(())
342}
343
344/// Resolve path elements to their final names (expanding #name refs).
345fn resolve_tracked_path(
346    path: &[PathElement],
347    tracker: &TrackedExpressionAttributes,
348) -> Vec<PathElement> {
349    path.iter()
350        .map(|elem| {
351            if let PathElement::Attribute(name) = elem {
352                if name.starts_with('#') {
353                    if let Ok(resolved) = tracker.resolve_name(name) {
354                        return PathElement::Attribute(resolved);
355                    }
356                }
357            }
358            elem.clone()
359        })
360        .collect()
361}
362
363/// Format a path for error messages in dynalite format: [a, b, [1], c].
364fn format_path_for_error(path: &[PathElement]) -> String {
365    let parts: Vec<String> = path
366        .iter()
367        .map(|elem| match elem {
368            PathElement::Attribute(name) => name.clone(),
369            PathElement::Index(i) => format!("[{i}]"),
370        })
371        .collect();
372    format!("[{}]", parts.join(", "))
373}
374
375/// Check for overlapping or conflicting document paths.
376///
377/// Two paths overlap if one is a prefix of the other (e.g., `a.b` and `a.b.c`).
378/// Two paths conflict if they share elements but diverge in type at the same
379/// position (e.g., `a[3].c` and `a.c[3]`).
380fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
381    for i in 0..paths.len() {
382        for j in (i + 1)..paths.len() {
383            let a = &paths[i];
384            let b = &paths[j];
385            let min_len = a.len().min(b.len());
386
387            // Check common prefix length
388            let mut common = 0;
389            for k in 0..min_len {
390                if a[k] == b[k] {
391                    common += 1;
392                } else {
393                    break;
394                }
395            }
396
397            if common == 0 {
398                continue;
399            }
400
401            // If one path is a prefix of the other, they overlap
402            if common == a.len() || common == b.len() {
403                let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
404                return Err(format!(
405                    "Invalid UpdateExpression: Two document paths overlap with each other; \
406                     must remove or rewrite one of these paths; \
407                     path one: {}, path two: {}",
408                    format_path_for_error(longer),
409                    format_path_for_error(shorter)
410                ));
411            }
412
413            // If paths share a prefix but diverge, they conflict
414            if common > 0 && common < min_len && a == b {
415                return Err(format!(
416                    "Invalid UpdateExpression: Two document paths conflict with each other; \
417                     must remove or rewrite one of these paths; \
418                     path one: {}, path two: {}",
419                    format_path_for_error(a),
420                    format_path_for_error(b)
421                ));
422            }
423        }
424    }
425    Ok(())
426}
427
428fn track_path_refs(
429    path: &[PathElement],
430    tracker: &TrackedExpressionAttributes,
431) -> Result<(), String> {
432    for elem in path {
433        if let PathElement::Attribute(name) = elem {
434            if name.starts_with('#') {
435                tracker.resolve_name(name)?;
436            }
437        }
438    }
439    Ok(())
440}
441
442fn track_set_value_refs(
443    value: &SetValue,
444    tracker: &TrackedExpressionAttributes,
445) -> Result<(), String> {
446    match value {
447        SetValue::Operand(op) => track_set_operand_refs(op, tracker),
448        SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
449            track_set_operand_refs(left, tracker)?;
450            track_set_operand_refs(right, tracker)
451        }
452    }
453}
454
455fn track_set_operand_refs(
456    operand: &SetOperand,
457    tracker: &TrackedExpressionAttributes,
458) -> Result<(), String> {
459    match operand {
460        SetOperand::Path(path) => track_path_refs(path, tracker),
461        SetOperand::ValueRef(name) => {
462            tracker.resolve_value(name)?;
463            Ok(())
464        }
465        SetOperand::IfNotExists(path, default) => {
466            track_path_refs(path, tracker)?;
467            track_set_operand_refs(default, tracker)
468        }
469        SetOperand::ListAppend(a, b) => {
470            track_set_operand_refs(a, tracker)?;
471            track_set_operand_refs(b, tracker)
472        }
473        SetOperand::Group(inner) => track_set_value_refs(inner, tracker),
474    }
475}
476
477/// Apply an update expression to an item (mutating it in place), tracking attribute usage.
478pub fn apply(
479    item: &mut HashMap<String, AttributeValue>,
480    expr: &UpdateExpr,
481    tracker: &TrackedExpressionAttributes,
482) -> Result<(), String> {
483    // Process SET actions.
484    //
485    // Every SET right-hand side is evaluated against the pre-update image, so
486    // that `SET a = :v, b = a` gives `b` the OLD value of `a` rather than the
487    // value assigned to `a` earlier in the same expression. DynamoDB applies
488    // the whole expression to the item as it appeared before the update, so all
489    // reads see the original snapshot. (Overlapping target paths are rejected by
490    // `check_path_overlaps`, so no SET can legitimately read another's output.)
491    let snapshot = item.clone();
492    for action in &expr.set_actions {
493        let resolved_path = resolve_path_elements(&action.path, tracker)?;
494        let value = evaluate_set_value(&action.value, &snapshot, tracker)?;
495        set_path(item, &resolved_path, value)?;
496    }
497
498    // Process REMOVE actions
499    for path in &expr.remove_actions {
500        let resolved_path = resolve_path_elements(path, tracker)?;
501        remove_path(item, &resolved_path)?;
502    }
503
504    // Process ADD actions
505    for action in &expr.add_actions {
506        let resolved_path = resolve_path_elements(&action.path, tracker)?;
507        let add_val = tracker.resolve_value(&action.value_ref)?.clone();
508        apply_add(item, &resolved_path, &add_val).map_err(|_| {
509            "An operand in the update expression has an incorrect data type".to_string()
510        })?;
511    }
512
513    // Process DELETE actions
514    for action in &expr.delete_actions {
515        let resolved_path = resolve_path_elements(&action.path, tracker)?;
516        let del_val = tracker.resolve_value(&action.value_ref)?.clone();
517        apply_delete(item, &resolved_path, &del_val).map_err(|_| {
518            "An operand in the update expression has an incorrect data type".to_string()
519        })?;
520    }
521
522    Ok(())
523}
524
525// ---------------------------------------------------------------------------
526// SET value evaluation
527// ---------------------------------------------------------------------------
528
529fn evaluate_set_value(
530    value: &SetValue,
531    item: &HashMap<String, AttributeValue>,
532    tracker: &TrackedExpressionAttributes,
533) -> Result<AttributeValue, String> {
534    match value {
535        SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
536        SetValue::Plus(left, right) => {
537            let lv = evaluate_set_operand(left, item, tracker)?;
538            let rv = evaluate_set_operand(right, item, tracker)?;
539            match (&lv, &rv) {
540                (AttributeValue::N(a), AttributeValue::N(b)) => {
541                    use bigdecimal::BigDecimal;
542                    use std::str::FromStr;
543                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
544                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
545                    let result = &da + &db;
546                    Ok(AttributeValue::N(format_number(&result)))
547                }
548                _ => Err("Operands for + must be numbers".to_string()),
549            }
550        }
551        SetValue::Minus(left, right) => {
552            let lv = evaluate_set_operand(left, item, tracker)?;
553            let rv = evaluate_set_operand(right, item, tracker)?;
554            match (&lv, &rv) {
555                (AttributeValue::N(a), AttributeValue::N(b)) => {
556                    use bigdecimal::BigDecimal;
557                    use std::str::FromStr;
558                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
559                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
560                    let result = &da - &db;
561                    Ok(AttributeValue::N(format_number(&result)))
562                }
563                _ => Err("Operands for - must be numbers".to_string()),
564            }
565        }
566    }
567}
568
569fn evaluate_set_operand(
570    operand: &SetOperand,
571    item: &HashMap<String, AttributeValue>,
572    tracker: &TrackedExpressionAttributes,
573) -> Result<AttributeValue, String> {
574    match operand {
575        SetOperand::Path(path) => {
576            let resolved = resolve_path_elements(path, tracker)?;
577            resolve_path(item, &resolved).ok_or_else(|| {
578                "The provided expression refers to an attribute that does not exist in the item"
579                    .to_string()
580            })
581        }
582        SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
583        SetOperand::IfNotExists(path, default) => {
584            let resolved = resolve_path_elements(path, tracker)?;
585            match resolve_path(item, &resolved) {
586                Some(existing) => Ok(existing),
587                None => evaluate_set_operand(default, item, tracker),
588            }
589        }
590        SetOperand::ListAppend(list1, list2) => {
591            let v1 = evaluate_set_operand(list1, item, tracker)?;
592            let v2 = evaluate_set_operand(list2, item, tracker)?;
593            match (v1, v2) {
594                (AttributeValue::L(mut a), AttributeValue::L(b)) => {
595                    a.extend(b);
596                    Ok(AttributeValue::L(a))
597                }
598                _ => Err("list_append requires two list operands".to_string()),
599            }
600        }
601        SetOperand::Group(inner) => evaluate_set_value(inner, item, tracker),
602    }
603}
604
605// ---------------------------------------------------------------------------
606// ADD action
607// ---------------------------------------------------------------------------
608
609/// Public wrapper for use by legacy `AttributeUpdates` support.
610pub fn apply_add_public(
611    item: &mut HashMap<String, AttributeValue>,
612    path: &[PathElement],
613    add_val: &AttributeValue,
614) -> Result<(), String> {
615    apply_add(item, path, add_val)
616}
617
618fn apply_add(
619    item: &mut HashMap<String, AttributeValue>,
620    path: &[PathElement],
621    add_val: &AttributeValue,
622) -> Result<(), String> {
623    let existing = resolve_path(item, path);
624
625    match (existing, add_val) {
626        // Number: add to existing number or create
627        (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
628            use bigdecimal::BigDecimal;
629            use std::str::FromStr;
630            let de = BigDecimal::from_str(&existing_n)
631                .map_err(|_| format!("Invalid number: {existing_n}"))?;
632            let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
633            let result = &de + &da;
634            set_path(item, path, AttributeValue::N(format_number(&result)))
635        }
636        (None, AttributeValue::N(_)) => {
637            // Create with the provided value
638            set_path(item, path, add_val.clone())
639        }
640
641        // String set: union
642        (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
643            for s in add_set {
644                if !existing_set.contains(s) {
645                    existing_set.push(s.clone());
646                }
647            }
648            set_path(item, path, AttributeValue::SS(existing_set))
649        }
650        (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
651
652        // Number set: union
653        (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
654            for n in add_set {
655                if !existing_set.contains(n) {
656                    existing_set.push(n.clone());
657                }
658            }
659            set_path(item, path, AttributeValue::NS(existing_set))
660        }
661        (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
662
663        // Binary set: union
664        (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
665            for b in add_set {
666                if !existing_set.contains(b) {
667                    existing_set.push(b.clone());
668                }
669            }
670            set_path(item, path, AttributeValue::BS(existing_set))
671        }
672        (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
673
674        // List: append elements (legacy AttributeUpdates behaviour)
675        (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
676            existing_list.extend(add_list.iter().cloned());
677            set_path(item, path, AttributeValue::L(existing_list))
678        }
679        (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
680
681        _ => Err("Type mismatch for attribute to update".to_string()),
682    }
683}
684
685// ---------------------------------------------------------------------------
686// DELETE action
687// ---------------------------------------------------------------------------
688
689/// Public wrapper for use by legacy `AttributeUpdates` support.
690pub fn apply_delete_public(
691    item: &mut HashMap<String, AttributeValue>,
692    path: &[PathElement],
693    del_val: &AttributeValue,
694) -> Result<(), String> {
695    apply_delete(item, path, del_val)
696}
697
698fn apply_delete(
699    item: &mut HashMap<String, AttributeValue>,
700    path: &[PathElement],
701    del_val: &AttributeValue,
702) -> Result<(), String> {
703    let existing = resolve_path(item, path);
704
705    match (existing, del_val) {
706        (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
707            let new_set: Vec<String> = existing_set
708                .into_iter()
709                .filter(|s| !del_set.contains(s))
710                .collect();
711            if new_set.is_empty() {
712                remove_path(item, path)
713            } else {
714                set_path(item, path, AttributeValue::SS(new_set))
715            }
716        }
717        (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
718            let new_set: Vec<String> = existing_set
719                .into_iter()
720                .filter(|n| !del_set.contains(n))
721                .collect();
722            if new_set.is_empty() {
723                remove_path(item, path)
724            } else {
725                set_path(item, path, AttributeValue::NS(new_set))
726            }
727        }
728        (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
729            let new_set: Vec<Vec<u8>> = existing_set
730                .into_iter()
731                .filter(|b| !del_set.contains(b))
732                .collect();
733            if new_set.is_empty() {
734                remove_path(item, path)
735            } else {
736                set_path(item, path, AttributeValue::BS(new_set))
737            }
738        }
739        (None, _) => Ok(()), // Nothing to delete from
740        _ => Err("Type mismatch for attribute to update".to_string()),
741    }
742}
743
744// ---------------------------------------------------------------------------
745// Parser
746// ---------------------------------------------------------------------------
747
748fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
749    actions.push(parse_set_action(stream)?);
750    while matches!(stream.peek(), Some(Token::Comma)) {
751        stream.next();
752        actions.push(parse_set_action(stream)?);
753    }
754    Ok(())
755}
756
757fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
758    let path = parse_raw_path(stream)?;
759    stream.expect(&Token::Eq)?;
760    let value = parse_set_value(stream)?;
761    Ok(SetAction { path, value })
762}
763
764fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
765    let left = parse_set_operand(stream)?;
766
767    match stream.peek() {
768        Some(Token::Plus) => {
769            stream.next();
770            let right = parse_set_operand(stream)?;
771            Ok(SetValue::Plus(left, right))
772        }
773        Some(Token::Minus) => {
774            stream.next();
775            let right = parse_set_operand(stream)?;
776            Ok(SetValue::Minus(left, right))
777        }
778        _ => Ok(SetValue::Operand(left)),
779    }
780}
781
782fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
783    // Check for functions: if_not_exists, list_append
784    if let Some(Token::Identifier(name)) = stream.peek() {
785        let func_name = name.to_lowercase();
786        let orig_name = name.clone();
787        match func_name.as_str() {
788            "if_not_exists" => {
789                stream.next();
790                stream.expect(&Token::LParen)?;
791
792                // First argument must be a document path (not a value ref or function)
793                match stream.peek() {
794                    Some(Token::ValueRef(_)) => {
795                        return Err(
796                            "Invalid UpdateExpression: Operator or function requires a document path; \
797                             operator or function: if_not_exists".to_string()
798                        );
799                    }
800                    Some(Token::Identifier(fname))
801                        if fname.to_lowercase() == "if_not_exists"
802                            || fname.to_lowercase() == "list_append" =>
803                    {
804                        return Err(
805                            "Invalid UpdateExpression: Operator or function requires a document path; \
806                             operator or function: if_not_exists".to_string()
807                        );
808                    }
809                    _ => {}
810                }
811
812                let path = parse_raw_path(stream)?;
813
814                // Check for correct number of operands
815                if !matches!(stream.peek(), Some(Token::Comma)) {
816                    return Err(
817                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
818                         operator or function: if_not_exists, number of operands: 1".to_string()
819                    );
820                }
821                stream.expect(&Token::Comma)?;
822                let default = parse_set_operand(stream)?;
823                stream.expect(&Token::RParen)?;
824                return Ok(SetOperand::IfNotExists(path, Box::new(default)));
825            }
826            "list_append" => {
827                stream.next();
828                stream.expect(&Token::LParen)?;
829                let list1 = parse_set_operand(stream)?;
830
831                // Check for correct number of operands
832                if !matches!(stream.peek(), Some(Token::Comma)) {
833                    return Err(
834                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
835                         operator or function: list_append, number of operands: 1".to_string()
836                    );
837                }
838                stream.expect(&Token::Comma)?;
839                let list2 = parse_set_operand(stream)?;
840                stream.expect(&Token::RParen)?;
841                return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
842            }
843            _ => {
844                // Check if this looks like a function call (identifier followed by '(')
845                // If so, report "Invalid function name" for unknown functions.
846                let saved_pos = stream.pos();
847                stream.next();
848                if matches!(stream.peek(), Some(Token::LParen)) {
849                    return Err(format!(
850                        "Invalid UpdateExpression: Invalid function name; function: {}",
851                        orig_name
852                    ));
853                }
854                // Rewind — not a function call, treat as path
855                stream.set_pos(saved_pos);
856            }
857        }
858    }
859
860    match stream.peek() {
861        // Parenthesised sub-expression, e.g. `(c - :v)`. The contents are a full
862        // SET value (operand or arithmetic), evaluated on the same BigDecimal path.
863        Some(Token::LParen) => {
864            stream.next();
865            let inner = parse_set_value(stream)?;
866            stream.expect(&Token::RParen)?;
867            Ok(SetOperand::Group(Box::new(inner)))
868        }
869        Some(Token::ValueRef(_)) => {
870            if let Some(Token::ValueRef(name)) = stream.next().cloned() {
871                Ok(SetOperand::ValueRef(name))
872            } else {
873                unreachable!()
874            }
875        }
876        Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
877            let path = parse_raw_path(stream)?;
878            Ok(SetOperand::Path(path))
879        }
880        Some(t) => Err(format!("Expected operand in SET, got {t}")),
881        None => Err("Expected operand in SET, got end of expression".to_string()),
882    }
883}
884
885fn parse_remove_clause(
886    stream: &mut TokenStream,
887    actions: &mut Vec<Vec<PathElement>>,
888) -> Result<(), String> {
889    actions.push(parse_raw_path(stream)?);
890    while matches!(stream.peek(), Some(Token::Comma)) {
891        stream.next();
892        actions.push(parse_raw_path(stream)?);
893    }
894    Ok(())
895}
896
897fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
898    actions.push(parse_add_action(stream)?);
899    while matches!(stream.peek(), Some(Token::Comma)) {
900        stream.next();
901        actions.push(parse_add_action(stream)?);
902    }
903    Ok(())
904}
905
906fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
907    let path = parse_raw_path(stream)?;
908    match stream.next() {
909        Some(Token::ValueRef(name)) => Ok(AddAction {
910            path,
911            value_ref: name.clone(),
912        }),
913        Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
914        None => Err("Expected value reference in ADD, got end of expression".to_string()),
915    }
916}
917
918fn parse_delete_clause(
919    stream: &mut TokenStream,
920    actions: &mut Vec<DeleteAction>,
921) -> Result<(), String> {
922    actions.push(parse_delete_action(stream)?);
923    while matches!(stream.peek(), Some(Token::Comma)) {
924        stream.next();
925        actions.push(parse_delete_action(stream)?);
926    }
927    Ok(())
928}
929
930fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
931    let path = parse_raw_path(stream)?;
932    match stream.next() {
933        Some(Token::ValueRef(name)) => Ok(DeleteAction {
934            path,
935            value_ref: name.clone(),
936        }),
937        Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
938        None => Err("Expected value reference in DELETE, got end of expression".to_string()),
939    }
940}
941
942/// Format a BigDecimal number, stripping unnecessary trailing zeros.
943/// DynamoDB returns numbers without scientific notation.
944fn format_number(n: &bigdecimal::BigDecimal) -> String {
945    let normalized = n.normalized();
946    // Force scale >= 0 so BigDecimal renders without scientific notation.
947    // When the exponent is negative (large integer like 1e38), with_scale(0)
948    // expands to full decimal digits.
949    if normalized.as_bigint_and_exponent().1 < 0 {
950        normalized.with_scale(0).to_string()
951    } else {
952        normalized.to_string()
953    }
954}
955
/// Unit tests for UpdateExpression parsing and application.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an item map from `(name, value)` pairs.
    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// Build an `ExpressionAttributeValues` map from `(placeholder, value)` pairs.
    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
    }

    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // Apply again — existing value should be preserved
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));
    }

    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::L(list) = &item["entries"] {
            assert_eq!(list.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set.len(), 3); // red, blue, green (blue deduplicated)
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);
    }

    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }

    /// #35(a): a later SET reads the pre-update value of an earlier target.
    #[test]
    fn test_set_reads_pre_update_snapshot() {
        let expr = parse("SET a = :v, b = a").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("a", AttributeValue::S("OLD".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::S("NEW".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["a"], AttributeValue::S("NEW".into()));
        assert_eq!(item["b"], AttributeValue::S("OLD".into()));
    }

    /// #35(b): a parenthesised arithmetic group parses and evaluates.
    #[test]
    fn test_set_parenthesised_arithmetic() {
        let expr = parse("SET c = (c - :v)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("c", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::N("3".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["c"], AttributeValue::N("7".into()));
    }

    /// #35(b): high-precision arithmetic inside a group stays exact (BigDecimal path).
    #[test]
    fn test_set_parenthesised_arithmetic_bigdecimal() {
        let expr = parse("SET c = (c + :v)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("c", AttributeValue::N("100000000000000000000".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["c"], AttributeValue::N("100000000000000000001".into()));
    }

    /// if_not_exists must receive a document path, not a value reference, first.
    #[test]
    fn test_if_not_exists_rejects_value_ref_first_operand() {
        let err = parse("SET a = if_not_exists(:v, :w)").unwrap_err();
        assert!(err.contains("requires a document path"));
    }

    /// if_not_exists with a single operand reports the operand count.
    #[test]
    fn test_if_not_exists_missing_default_operand() {
        let err = parse("SET a = if_not_exists(a)").unwrap_err();
        assert!(err.contains("if_not_exists, number of operands: 1"));
    }

    /// list_append with a single operand reports the operand count.
    #[test]
    fn test_list_append_missing_second_operand() {
        let err = parse("SET a = list_append(a)").unwrap_err();
        assert!(err.contains("list_append, number of operands: 1"));
    }

    /// Calling an unsupported function is rejected by name.
    #[test]
    fn test_unknown_function_name() {
        let err = parse("SET a = bogus_fn(b)").unwrap_err();
        assert!(err.contains("Invalid function name; function: bogus_fn"));
    }

    /// ADD requires a value reference after the path.
    #[test]
    fn test_add_requires_value_ref() {
        assert!(parse("ADD tally other").is_err());
    }

    /// DELETE requires a value reference after the path.
    #[test]
    fn test_delete_requires_value_ref() {
        assert!(parse("DELETE colors other").is_err());
    }

    /// Large integers are rendered with all digits, never in exponential form.
    #[test]
    fn test_format_number_large_integer_is_plain() {
        let n: bigdecimal::BigDecimal = "1e38".parse().unwrap();
        assert_eq!(format_number(&n), format!("1{}", "0".repeat(38)));
    }
}