Skip to main content

dynoxide/expressions/
update.rs

//! UpdateExpression parsing and evaluation.
//!
//! Supports the SET, REMOVE, ADD and DELETE clauses.
4
5use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{Token, TokenStream, tokenize};
7use crate::expressions::{
8    PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
9    set_path,
10};
11use crate::types::AttributeValue;
12use std::collections::HashMap;
13
/// Parsed update expression with all clause actions.
///
/// Produced by `parse`. Each vector holds the actions of one clause in the
/// order they appeared in the expression; a clause that was not present leaves
/// its vector empty. Paths are stored as written, so `#name` placeholders are
/// resolved later (via the expression-attribute tracker).
#[derive(Debug)]
pub struct UpdateExpr {
    /// Actions from the `SET` clause.
    pub set_actions: Vec<SetAction>,
    /// Target document paths from the `REMOVE` clause.
    pub remove_actions: Vec<Vec<PathElement>>,
    /// Actions from the `ADD` clause.
    pub add_actions: Vec<AddAction>,
    /// Actions from the `DELETE` clause.
    pub delete_actions: Vec<DeleteAction>,
}
22
/// A SET action: `path = value_expr`
#[derive(Debug)]
pub struct SetAction {
    /// Document path being assigned (may contain unresolved `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Right-hand side expression whose result is stored at `path`.
    pub value: SetValue,
}
29
/// Value expression for SET.
///
/// The parser accepts at most one binary operator, so arithmetic is never
/// nested (`a + b + c` is not representable).
#[derive(Debug)]
pub enum SetValue {
    /// Direct value or path reference
    Operand(SetOperand),
    /// `operand + operand` (both sides must evaluate to numbers)
    Plus(SetOperand, SetOperand),
    /// `operand - operand` (both sides must evaluate to numbers)
    Minus(SetOperand, SetOperand),
}
40
/// An operand in a SET expression.
#[derive(Debug)]
pub enum SetOperand {
    /// Document path whose current value is read from the item; evaluating it
    /// fails if the attribute does not exist.
    Path(Vec<PathElement>),
    /// Name of an expression attribute value placeholder, as produced by the
    /// tokenizer.
    ValueRef(String),
    /// `if_not_exists(path, default)`: the value at `path` if present,
    /// otherwise the evaluated `default`.
    IfNotExists(Vec<PathElement>, Box<SetOperand>),
    /// `list_append(a, b)`: concatenation of two list operands.
    ListAppend(Box<SetOperand>, Box<SetOperand>),
}
49
/// An ADD action: `path :value`
#[derive(Debug)]
pub struct AddAction {
    /// Document path to add to (may contain unresolved `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Name of the expression attribute value placeholder holding the operand.
    pub value_ref: String,
}
56
/// A DELETE action: `path :value`
#[derive(Debug)]
pub struct DeleteAction {
    /// Document path of the set to remove elements from (may contain
    /// unresolved `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Name of the expression attribute value placeholder holding the elements
    /// to remove.
    pub value_ref: String,
}
63
64/// Parse an UpdateExpression string.
65pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
66    let tokens =
67        tokenize(expr).map_err(|e| format!("Invalid UpdateExpression: Syntax error; {e}"))?;
68    let mut stream = TokenStream::new(tokens);
69
70    let mut set_actions = Vec::new();
71    let mut remove_actions = Vec::new();
72    let mut add_actions = Vec::new();
73    let mut delete_actions = Vec::new();
74
75    let mut seen_set = false;
76    let mut seen_remove = false;
77    let mut seen_add = false;
78    let mut seen_delete = false;
79
80    while !stream.at_end() {
81        match stream.peek() {
82            Some(Token::Set) => {
83                if seen_set {
84                    return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
85                }
86                seen_set = true;
87                stream.next();
88                parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
89            }
90            Some(Token::Remove) => {
91                if seen_remove {
92                    return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
93                }
94                seen_remove = true;
95                stream.next();
96                parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
97            }
98            Some(Token::Add) => {
99                if seen_add {
100                    return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
101                }
102                seen_add = true;
103                stream.next();
104                parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
105            }
106            Some(Token::Delete) => {
107                if seen_delete {
108                    return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
109                }
110                seen_delete = true;
111                stream.next();
112                parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
113            }
114            Some(t) => {
115                return Err(format!(
116                    "Invalid UpdateExpression: Syntax error; token: {t}"
117                ));
118            }
119            None => break,
120        }
121    }
122
123    Ok(UpdateExpr {
124        set_actions,
125        remove_actions,
126        add_actions,
127        delete_actions,
128    })
129}
130
/// Normalise a sub-parser error into an `Invalid UpdateExpression:` message.
///
/// Messages that already carry that prefix are returned untouched. Reserved
/// keyword errors get just the prefix; everything else is additionally
/// labelled as a `Syntax error;`.
fn wrap_syntax_error(err: String) -> String {
    if err.starts_with("Invalid UpdateExpression:") {
        return err;
    }
    let category = if err.starts_with("Attribute name is a reserved keyword") {
        ""
    } else {
        "Syntax error; "
    };
    format!("Invalid UpdateExpression: {category}{err}")
}
142
143/// Walk an UpdateExpr and track all attribute name and value references
144/// without actually evaluating or modifying any item. This is used for
145/// pre-validation: checking that all referenced names/values are defined,
146/// and detecting unused names/values.
147pub fn track_references(
148    expr: &UpdateExpr,
149    tracker: &TrackedExpressionAttributes,
150) -> Result<(), String> {
151    // Collect all target paths for overlap/conflict detection
152    let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
153
154    for action in &expr.set_actions {
155        track_path_refs(&action.path, tracker)?;
156        track_set_value_refs(&action.value, tracker)?;
157        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
158    }
159    for path in &expr.remove_actions {
160        track_path_refs(path, tracker)?;
161        all_target_paths.push(resolve_tracked_path(path, tracker));
162    }
163    for action in &expr.add_actions {
164        track_path_refs(&action.path, tracker)?;
165        let val = tracker.resolve_value(&action.value_ref)?;
166        // Validate ADD operand type statically
167        validate_add_type(val)?;
168        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
169    }
170    for action in &expr.delete_actions {
171        track_path_refs(&action.path, tracker)?;
172        let val = tracker.resolve_value(&action.value_ref)?;
173        // Validate DELETE operand type statically
174        validate_delete_type(val)?;
175        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
176    }
177
178    // Static type validation for SET value expressions
179    for action in &expr.set_actions {
180        validate_set_value_types(&action.value, tracker)?;
181    }
182
183    // Check for overlapping/conflicting paths
184    check_path_overlaps(&all_target_paths)?;
185
186    Ok(())
187}
188
189/// Validate that an ADD operand has a compatible type.
190fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
191    use crate::types::AttributeValue;
192    match val {
193        AttributeValue::N(_)
194        | AttributeValue::SS(_)
195        | AttributeValue::NS(_)
196        | AttributeValue::BS(_) => Ok(()),
197        _ => Err(format!(
198            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
199             operator: ADD, operand type: {}",
200            dynamo_type_name(val)
201        )),
202    }
203}
204
205/// Validate that a DELETE operand has a compatible type.
206fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
207    use crate::types::AttributeValue;
208    match val {
209        AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
210        _ => Err(format!(
211            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
212             operator: DELETE, operand type: {}",
213            dynamo_type_name(val)
214        )),
215    }
216}
217
218/// Map an AttributeValue to its DynamoDB type name for error messages.
219fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
220    use crate::types::AttributeValue;
221    match val {
222        AttributeValue::S(_) => "STRING",
223        AttributeValue::N(_) => "NUMBER",
224        AttributeValue::B(_) => "BINARY",
225        AttributeValue::BOOL(_) => "BOOLEAN",
226        AttributeValue::NULL(_) => "NULL",
227        AttributeValue::SS(_) => "SS",
228        AttributeValue::NS(_) => "NS",
229        AttributeValue::BS(_) => "BS",
230        AttributeValue::L(_) => "LIST",
231        AttributeValue::M(_) => "MAP",
232    }
233}
234
235/// Validate types for SET value expressions (arithmetic, list_append).
236fn validate_set_value_types(
237    value: &SetValue,
238    tracker: &TrackedExpressionAttributes,
239) -> Result<(), String> {
240    match value {
241        SetValue::Operand(op) => validate_set_operand_types(op, tracker),
242        SetValue::Plus(left, right) => {
243            validate_arithmetic_operand(left, "+", tracker)?;
244            validate_arithmetic_operand(right, "+", tracker)
245        }
246        SetValue::Minus(left, right) => {
247            validate_arithmetic_operand(left, "-", tracker)?;
248            validate_arithmetic_operand(right, "-", tracker)
249        }
250    }
251}
252
253/// Validate that an operand used in + or - is a number (if it's a value ref).
254fn validate_arithmetic_operand(
255    operand: &SetOperand,
256    op: &str,
257    tracker: &TrackedExpressionAttributes,
258) -> Result<(), String> {
259    use crate::types::AttributeValue;
260    match operand {
261        SetOperand::ValueRef(name) => {
262            let val = tracker.resolve_value(name)?;
263            if !matches!(val, AttributeValue::N(_)) {
264                return Err(format!(
265                    "Invalid UpdateExpression: Incorrect operand type for operator or function; \
266                     operator or function: {op}, operand type: {}",
267                    dynamo_type_name(val)
268                ));
269            }
270            Ok(())
271        }
272        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
273        SetOperand::ListAppend(a, b) => {
274            validate_list_append_operand(a, tracker)?;
275            validate_list_append_operand(b, tracker)
276        }
277        SetOperand::Path(_) => Ok(()), // Path types checked at runtime
278    }
279}
280
281/// Validate types for a set operand (recursively).
282fn validate_set_operand_types(
283    operand: &SetOperand,
284    tracker: &TrackedExpressionAttributes,
285) -> Result<(), String> {
286    match operand {
287        SetOperand::ListAppend(a, b) => {
288            validate_list_append_operand(a, tracker)?;
289            validate_list_append_operand(b, tracker)
290        }
291        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
292        _ => Ok(()),
293    }
294}
295
296/// Validate a list_append operand is a list if it's a value ref.
297fn validate_list_append_operand(
298    operand: &SetOperand,
299    tracker: &TrackedExpressionAttributes,
300) -> Result<(), String> {
301    use crate::types::AttributeValue;
302    if let SetOperand::ValueRef(name) = operand {
303        let val = tracker.resolve_value(name)?;
304        if !matches!(val, AttributeValue::L(_)) {
305            return Err(format!(
306                "Invalid UpdateExpression: Incorrect operand type for operator or function; \
307                 operator or function: list_append, operand type: {}",
308                dynamo_type_name(val)
309            ));
310        }
311    }
312    Ok(())
313}
314
315/// Resolve path elements to their final names (expanding #name refs).
316fn resolve_tracked_path(
317    path: &[PathElement],
318    tracker: &TrackedExpressionAttributes,
319) -> Vec<PathElement> {
320    path.iter()
321        .map(|elem| {
322            if let PathElement::Attribute(name) = elem {
323                if name.starts_with('#') {
324                    if let Ok(resolved) = tracker.resolve_name(name) {
325                        return PathElement::Attribute(resolved);
326                    }
327                }
328            }
329            elem.clone()
330        })
331        .collect()
332}
333
334/// Format a path for error messages in dynalite format: [a, b, [1], c].
335fn format_path_for_error(path: &[PathElement]) -> String {
336    let parts: Vec<String> = path
337        .iter()
338        .map(|elem| match elem {
339            PathElement::Attribute(name) => name.clone(),
340            PathElement::Index(i) => format!("[{i}]"),
341        })
342        .collect();
343    format!("[{}]", parts.join(", "))
344}
345
346/// Check for overlapping or conflicting document paths.
347///
348/// Two paths overlap if one is a prefix of the other (e.g., `a.b` and `a.b.c`).
349/// Two paths conflict if they share elements but diverge in type at the same
350/// position (e.g., `a[3].c` and `a.c[3]`).
351fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
352    for i in 0..paths.len() {
353        for j in (i + 1)..paths.len() {
354            let a = &paths[i];
355            let b = &paths[j];
356            let min_len = a.len().min(b.len());
357
358            // Check common prefix length
359            let mut common = 0;
360            for k in 0..min_len {
361                if a[k] == b[k] {
362                    common += 1;
363                } else {
364                    break;
365                }
366            }
367
368            if common == 0 {
369                continue;
370            }
371
372            // If one path is a prefix of the other, they overlap
373            if common == a.len() || common == b.len() {
374                let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
375                return Err(format!(
376                    "Invalid UpdateExpression: Two document paths overlap with each other; \
377                     must remove or rewrite one of these paths; \
378                     path one: {}, path two: {}",
379                    format_path_for_error(longer),
380                    format_path_for_error(shorter)
381                ));
382            }
383
384            // If paths share a prefix but diverge, they conflict
385            if common > 0 && common < min_len && a == b {
386                return Err(format!(
387                    "Invalid UpdateExpression: Two document paths conflict with each other; \
388                     must remove or rewrite one of these paths; \
389                     path one: {}, path two: {}",
390                    format_path_for_error(a),
391                    format_path_for_error(b)
392                ));
393            }
394        }
395    }
396    Ok(())
397}
398
399fn track_path_refs(
400    path: &[PathElement],
401    tracker: &TrackedExpressionAttributes,
402) -> Result<(), String> {
403    for elem in path {
404        if let PathElement::Attribute(name) = elem {
405            if name.starts_with('#') {
406                tracker.resolve_name(name)?;
407            }
408        }
409    }
410    Ok(())
411}
412
413fn track_set_value_refs(
414    value: &SetValue,
415    tracker: &TrackedExpressionAttributes,
416) -> Result<(), String> {
417    match value {
418        SetValue::Operand(op) => track_set_operand_refs(op, tracker),
419        SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
420            track_set_operand_refs(left, tracker)?;
421            track_set_operand_refs(right, tracker)
422        }
423    }
424}
425
426fn track_set_operand_refs(
427    operand: &SetOperand,
428    tracker: &TrackedExpressionAttributes,
429) -> Result<(), String> {
430    match operand {
431        SetOperand::Path(path) => track_path_refs(path, tracker),
432        SetOperand::ValueRef(name) => {
433            tracker.resolve_value(name)?;
434            Ok(())
435        }
436        SetOperand::IfNotExists(path, default) => {
437            track_path_refs(path, tracker)?;
438            track_set_operand_refs(default, tracker)
439        }
440        SetOperand::ListAppend(a, b) => {
441            track_set_operand_refs(a, tracker)?;
442            track_set_operand_refs(b, tracker)
443        }
444    }
445}
446
447/// Apply an update expression to an item (mutating it in place), tracking attribute usage.
448pub fn apply(
449    item: &mut HashMap<String, AttributeValue>,
450    expr: &UpdateExpr,
451    tracker: &TrackedExpressionAttributes,
452) -> Result<(), String> {
453    // Process SET actions
454    for action in &expr.set_actions {
455        let resolved_path = resolve_path_elements(&action.path, tracker)?;
456        let value = evaluate_set_value(&action.value, item, tracker)?;
457        set_path(item, &resolved_path, value)?;
458    }
459
460    // Process REMOVE actions
461    for path in &expr.remove_actions {
462        let resolved_path = resolve_path_elements(path, tracker)?;
463        remove_path(item, &resolved_path)?;
464    }
465
466    // Process ADD actions
467    for action in &expr.add_actions {
468        let resolved_path = resolve_path_elements(&action.path, tracker)?;
469        let add_val = tracker.resolve_value(&action.value_ref)?.clone();
470        apply_add(item, &resolved_path, &add_val).map_err(|_| {
471            "An operand in the update expression has an incorrect data type".to_string()
472        })?;
473    }
474
475    // Process DELETE actions
476    for action in &expr.delete_actions {
477        let resolved_path = resolve_path_elements(&action.path, tracker)?;
478        let del_val = tracker.resolve_value(&action.value_ref)?.clone();
479        apply_delete(item, &resolved_path, &del_val).map_err(|_| {
480            "An operand in the update expression has an incorrect data type".to_string()
481        })?;
482    }
483
484    Ok(())
485}
486
487// ---------------------------------------------------------------------------
488// SET value evaluation
489// ---------------------------------------------------------------------------
490
491fn evaluate_set_value(
492    value: &SetValue,
493    item: &HashMap<String, AttributeValue>,
494    tracker: &TrackedExpressionAttributes,
495) -> Result<AttributeValue, String> {
496    match value {
497        SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
498        SetValue::Plus(left, right) => {
499            let lv = evaluate_set_operand(left, item, tracker)?;
500            let rv = evaluate_set_operand(right, item, tracker)?;
501            match (&lv, &rv) {
502                (AttributeValue::N(a), AttributeValue::N(b)) => {
503                    use bigdecimal::BigDecimal;
504                    use std::str::FromStr;
505                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
506                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
507                    let result = &da + &db;
508                    Ok(AttributeValue::N(format_number(&result)))
509                }
510                _ => Err("Operands for + must be numbers".to_string()),
511            }
512        }
513        SetValue::Minus(left, right) => {
514            let lv = evaluate_set_operand(left, item, tracker)?;
515            let rv = evaluate_set_operand(right, item, tracker)?;
516            match (&lv, &rv) {
517                (AttributeValue::N(a), AttributeValue::N(b)) => {
518                    use bigdecimal::BigDecimal;
519                    use std::str::FromStr;
520                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
521                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
522                    let result = &da - &db;
523                    Ok(AttributeValue::N(format_number(&result)))
524                }
525                _ => Err("Operands for - must be numbers".to_string()),
526            }
527        }
528    }
529}
530
531fn evaluate_set_operand(
532    operand: &SetOperand,
533    item: &HashMap<String, AttributeValue>,
534    tracker: &TrackedExpressionAttributes,
535) -> Result<AttributeValue, String> {
536    match operand {
537        SetOperand::Path(path) => {
538            let resolved = resolve_path_elements(path, tracker)?;
539            resolve_path(item, &resolved).ok_or_else(|| {
540                "The provided expression refers to an attribute that does not exist in the item"
541                    .to_string()
542            })
543        }
544        SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
545        SetOperand::IfNotExists(path, default) => {
546            let resolved = resolve_path_elements(path, tracker)?;
547            match resolve_path(item, &resolved) {
548                Some(existing) => Ok(existing),
549                None => evaluate_set_operand(default, item, tracker),
550            }
551        }
552        SetOperand::ListAppend(list1, list2) => {
553            let v1 = evaluate_set_operand(list1, item, tracker)?;
554            let v2 = evaluate_set_operand(list2, item, tracker)?;
555            match (v1, v2) {
556                (AttributeValue::L(mut a), AttributeValue::L(b)) => {
557                    a.extend(b);
558                    Ok(AttributeValue::L(a))
559                }
560                _ => Err("list_append requires two list operands".to_string()),
561            }
562        }
563    }
564}
565
566// ---------------------------------------------------------------------------
567// ADD action
568// ---------------------------------------------------------------------------
569
/// Public wrapper for use by legacy `AttributeUpdates` support.
///
/// Performs one `ADD` of `add_val` at `path` in `item` by delegating to the
/// private `apply_add`. `path` is used as given: no `#name` placeholder
/// resolution happens here. Errors from `apply_add` are returned unmodified.
pub fn apply_add_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    add_val: &AttributeValue,
) -> Result<(), String> {
    apply_add(item, path, add_val)
}
578
579fn apply_add(
580    item: &mut HashMap<String, AttributeValue>,
581    path: &[PathElement],
582    add_val: &AttributeValue,
583) -> Result<(), String> {
584    let existing = resolve_path(item, path);
585
586    match (existing, add_val) {
587        // Number: add to existing number or create
588        (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
589            use bigdecimal::BigDecimal;
590            use std::str::FromStr;
591            let de = BigDecimal::from_str(&existing_n)
592                .map_err(|_| format!("Invalid number: {existing_n}"))?;
593            let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
594            let result = &de + &da;
595            set_path(item, path, AttributeValue::N(format_number(&result)))
596        }
597        (None, AttributeValue::N(_)) => {
598            // Create with the provided value
599            set_path(item, path, add_val.clone())
600        }
601
602        // String set: union
603        (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
604            for s in add_set {
605                if !existing_set.contains(s) {
606                    existing_set.push(s.clone());
607                }
608            }
609            set_path(item, path, AttributeValue::SS(existing_set))
610        }
611        (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
612
613        // Number set: union
614        (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
615            for n in add_set {
616                if !existing_set.contains(n) {
617                    existing_set.push(n.clone());
618                }
619            }
620            set_path(item, path, AttributeValue::NS(existing_set))
621        }
622        (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
623
624        // Binary set: union
625        (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
626            for b in add_set {
627                if !existing_set.contains(b) {
628                    existing_set.push(b.clone());
629                }
630            }
631            set_path(item, path, AttributeValue::BS(existing_set))
632        }
633        (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
634
635        // List: append elements (legacy AttributeUpdates behaviour)
636        (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
637            existing_list.extend(add_list.iter().cloned());
638            set_path(item, path, AttributeValue::L(existing_list))
639        }
640        (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
641
642        _ => Err("Type mismatch for attribute to update".to_string()),
643    }
644}
645
646// ---------------------------------------------------------------------------
647// DELETE action
648// ---------------------------------------------------------------------------
649
/// Public wrapper for use by legacy `AttributeUpdates` support.
///
/// Performs one `DELETE` of the elements in `del_val` from the set at `path`
/// in `item` by delegating to the private `apply_delete`. `path` is used as
/// given: no `#name` placeholder resolution happens here. Errors from
/// `apply_delete` are returned unmodified.
pub fn apply_delete_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    del_val: &AttributeValue,
) -> Result<(), String> {
    apply_delete(item, path, del_val)
}
658
659fn apply_delete(
660    item: &mut HashMap<String, AttributeValue>,
661    path: &[PathElement],
662    del_val: &AttributeValue,
663) -> Result<(), String> {
664    let existing = resolve_path(item, path);
665
666    match (existing, del_val) {
667        (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
668            let new_set: Vec<String> = existing_set
669                .into_iter()
670                .filter(|s| !del_set.contains(s))
671                .collect();
672            if new_set.is_empty() {
673                remove_path(item, path)
674            } else {
675                set_path(item, path, AttributeValue::SS(new_set))
676            }
677        }
678        (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
679            let new_set: Vec<String> = existing_set
680                .into_iter()
681                .filter(|n| !del_set.contains(n))
682                .collect();
683            if new_set.is_empty() {
684                remove_path(item, path)
685            } else {
686                set_path(item, path, AttributeValue::NS(new_set))
687            }
688        }
689        (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
690            let new_set: Vec<Vec<u8>> = existing_set
691                .into_iter()
692                .filter(|b| !del_set.contains(b))
693                .collect();
694            if new_set.is_empty() {
695                remove_path(item, path)
696            } else {
697                set_path(item, path, AttributeValue::BS(new_set))
698            }
699        }
700        (None, _) => Ok(()), // Nothing to delete from
701        _ => Err("Type mismatch for attribute to update".to_string()),
702    }
703}
704
705// ---------------------------------------------------------------------------
706// Parser
707// ---------------------------------------------------------------------------
708
709fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
710    actions.push(parse_set_action(stream)?);
711    while matches!(stream.peek(), Some(Token::Comma)) {
712        stream.next();
713        actions.push(parse_set_action(stream)?);
714    }
715    Ok(())
716}
717
718fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
719    let path = parse_raw_path(stream)?;
720    stream.expect(&Token::Eq)?;
721    let value = parse_set_value(stream)?;
722    Ok(SetAction { path, value })
723}
724
725fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
726    let left = parse_set_operand(stream)?;
727
728    match stream.peek() {
729        Some(Token::Plus) => {
730            stream.next();
731            let right = parse_set_operand(stream)?;
732            Ok(SetValue::Plus(left, right))
733        }
734        Some(Token::Minus) => {
735            stream.next();
736            let right = parse_set_operand(stream)?;
737            Ok(SetValue::Minus(left, right))
738        }
739        _ => Ok(SetValue::Operand(left)),
740    }
741}
742
743fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
744    // Check for functions: if_not_exists, list_append
745    if let Some(Token::Identifier(name)) = stream.peek() {
746        let func_name = name.to_lowercase();
747        let orig_name = name.clone();
748        match func_name.as_str() {
749            "if_not_exists" => {
750                stream.next();
751                stream.expect(&Token::LParen)?;
752
753                // First argument must be a document path (not a value ref or function)
754                match stream.peek() {
755                    Some(Token::ValueRef(_)) => {
756                        return Err(
757                            "Invalid UpdateExpression: Operator or function requires a document path; \
758                             operator or function: if_not_exists".to_string()
759                        );
760                    }
761                    Some(Token::Identifier(fname))
762                        if fname.to_lowercase() == "if_not_exists"
763                            || fname.to_lowercase() == "list_append" =>
764                    {
765                        return Err(
766                            "Invalid UpdateExpression: Operator or function requires a document path; \
767                             operator or function: if_not_exists".to_string()
768                        );
769                    }
770                    _ => {}
771                }
772
773                let path = parse_raw_path(stream)?;
774
775                // Check for correct number of operands
776                if !matches!(stream.peek(), Some(Token::Comma)) {
777                    return Err(
778                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
779                         operator or function: if_not_exists, number of operands: 1".to_string()
780                    );
781                }
782                stream.expect(&Token::Comma)?;
783                let default = parse_set_operand(stream)?;
784                stream.expect(&Token::RParen)?;
785                return Ok(SetOperand::IfNotExists(path, Box::new(default)));
786            }
787            "list_append" => {
788                stream.next();
789                stream.expect(&Token::LParen)?;
790                let list1 = parse_set_operand(stream)?;
791
792                // Check for correct number of operands
793                if !matches!(stream.peek(), Some(Token::Comma)) {
794                    return Err(
795                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
796                         operator or function: list_append, number of operands: 1".to_string()
797                    );
798                }
799                stream.expect(&Token::Comma)?;
800                let list2 = parse_set_operand(stream)?;
801                stream.expect(&Token::RParen)?;
802                return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
803            }
804            _ => {
805                // Check if this looks like a function call (identifier followed by '(')
806                // If so, report "Invalid function name" for unknown functions.
807                let saved_pos = stream.pos();
808                stream.next();
809                if matches!(stream.peek(), Some(Token::LParen)) {
810                    return Err(format!(
811                        "Invalid UpdateExpression: Invalid function name; function: {}",
812                        orig_name
813                    ));
814                }
815                // Rewind — not a function call, treat as path
816                stream.set_pos(saved_pos);
817            }
818        }
819    }
820
821    match stream.peek() {
822        Some(Token::ValueRef(_)) => {
823            if let Some(Token::ValueRef(name)) = stream.next().cloned() {
824                Ok(SetOperand::ValueRef(name))
825            } else {
826                unreachable!()
827            }
828        }
829        Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
830            let path = parse_raw_path(stream)?;
831            Ok(SetOperand::Path(path))
832        }
833        Some(t) => Err(format!("Expected operand in SET, got {t}")),
834        None => Err("Expected operand in SET, got end of expression".to_string()),
835    }
836}
837
838fn parse_remove_clause(
839    stream: &mut TokenStream,
840    actions: &mut Vec<Vec<PathElement>>,
841) -> Result<(), String> {
842    actions.push(parse_raw_path(stream)?);
843    while matches!(stream.peek(), Some(Token::Comma)) {
844        stream.next();
845        actions.push(parse_raw_path(stream)?);
846    }
847    Ok(())
848}
849
850fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
851    actions.push(parse_add_action(stream)?);
852    while matches!(stream.peek(), Some(Token::Comma)) {
853        stream.next();
854        actions.push(parse_add_action(stream)?);
855    }
856    Ok(())
857}
858
859fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
860    let path = parse_raw_path(stream)?;
861    match stream.next() {
862        Some(Token::ValueRef(name)) => Ok(AddAction {
863            path,
864            value_ref: name.clone(),
865        }),
866        Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
867        None => Err("Expected value reference in ADD, got end of expression".to_string()),
868    }
869}
870
871fn parse_delete_clause(
872    stream: &mut TokenStream,
873    actions: &mut Vec<DeleteAction>,
874) -> Result<(), String> {
875    actions.push(parse_delete_action(stream)?);
876    while matches!(stream.peek(), Some(Token::Comma)) {
877        stream.next();
878        actions.push(parse_delete_action(stream)?);
879    }
880    Ok(())
881}
882
883fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
884    let path = parse_raw_path(stream)?;
885    match stream.next() {
886        Some(Token::ValueRef(name)) => Ok(DeleteAction {
887            path,
888            value_ref: name.clone(),
889        }),
890        Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
891        None => Err("Expected value reference in DELETE, got end of expression".to_string()),
892    }
893}
894
895/// Format a BigDecimal number, stripping unnecessary trailing zeros.
896/// DynamoDB returns numbers without scientific notation.
897fn format_number(n: &bigdecimal::BigDecimal) -> String {
898    let normalized = n.normalized();
899    // Force scale >= 0 so BigDecimal renders without scientific notation.
900    // When the exponent is negative (large integer like 1e38), with_scale(0)
901    // expands to full decimal digits.
902    if normalized.as_bigint_and_exponent().1 < 0 {
903        normalized.with_scale(0).to_string()
904    } else {
905        normalized.to_string()
906    }
907}
908
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an item map from `(attribute name, value)` pairs.
    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// Build an expression-attribute-values map from `(placeholder, value)` pairs.
    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    /// Wrap attribute name/value maps in the tracker that `apply` consumes.
    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
        assert!(expr.add_actions.is_empty());
        assert!(expr.delete_actions.is_empty());
    }

    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // Apply again with the attribute holding a value that differs from the
        // default, so an accidental overwrite would actually be detected.
        item.insert("hits".to_string(), AttributeValue::N("7".into()));
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("7".into()));
    }

    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        // Existing elements come first, appended ones after.
        assert_eq!(
            item["entries"],
            AttributeValue::L(vec![
                AttributeValue::S("a".into()),
                AttributeValue::S("b".into()),
            ])
        );
    }

    #[test]
    fn test_set_list_append_prepend() {
        let expr = parse("SET entries = list_append(:new, entries)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        // Operand order decides element order: the value list goes first here.
        assert_eq!(
            item["entries"],
            AttributeValue::L(vec![
                AttributeValue::S("b".into()),
                AttributeValue::S("a".into()),
            ])
        );
    }

    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set.len(), 3); // red, blue, green (blue deduplicated)
            assert!(set.contains(&"red".to_string()));
            assert!(set.contains(&"blue".to_string()));
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_add_delete_require_value_ref() {
        // ADD and DELETE actions must end in an expression attribute value.
        assert!(parse("ADD tally").is_err());
        assert!(parse("DELETE colors").is_err());
    }

    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);

        // Both clauses take effect in a single apply.
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("old_attr", AttributeValue::S("stale".into())),
        ]);
        let av = vals(&[(":name", AttributeValue::S("fresh".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["label"], AttributeValue::S("fresh".into()));
        assert!(!item.contains_key("old_attr"));
    }

    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }

    #[test]
    fn test_format_number_plain_notation() {
        use std::str::FromStr;
        let fmt = |s: &str| format_number(&bigdecimal::BigDecimal::from_str(s).unwrap());
        assert_eq!(fmt("100"), "100");
        assert_eq!(fmt("1.50"), "1.5");
        assert_eq!(fmt("-0.25"), "-0.25");
        // Large exponents must be expanded rather than rendered as 1E+38.
        assert_eq!(fmt("1e38"), format!("1{}", "0".repeat(38)));
    }
}