Skip to main content

dynoxide/expressions/
update.rs

1//! UpdateExpression parsing and evaluation.
2//!
3//! Supports SET, REMOVE, ADD, DELETE clauses.
4
5use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{
7    Token, TokenStream, near_window_parser, near_window_tokenizer, tokenize,
8};
9use crate::expressions::{
10    PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
11    set_path,
12};
13use crate::types::AttributeValue;
14use std::collections::HashMap;
15
16/// Parsed update expression with all clause actions.
17#[derive(Debug)]
18pub struct UpdateExpr {
19    pub set_actions: Vec<SetAction>,
20    pub remove_actions: Vec<Vec<PathElement>>,
21    pub add_actions: Vec<AddAction>,
22    pub delete_actions: Vec<DeleteAction>,
23}
24
25/// A SET action: `path = value_expr`
26#[derive(Debug)]
27pub struct SetAction {
28    pub path: Vec<PathElement>,
29    pub value: SetValue,
30}
31
32/// Value expression for SET.
33#[derive(Debug)]
34pub enum SetValue {
35    /// Direct value or path reference
36    Operand(SetOperand),
37    /// `operand + operand`
38    Plus(SetOperand, SetOperand),
39    /// `operand - operand`
40    Minus(SetOperand, SetOperand),
41}
42
43/// An operand in a SET expression.
44#[derive(Debug)]
45pub enum SetOperand {
46    Path(Vec<PathElement>),
47    ValueRef(String),
48    IfNotExists(Vec<PathElement>, Box<SetOperand>),
49    ListAppend(Box<SetOperand>, Box<SetOperand>),
50}
51
52/// An ADD action: `path :value`
53#[derive(Debug)]
54pub struct AddAction {
55    pub path: Vec<PathElement>,
56    pub value_ref: String,
57}
58
59/// A DELETE action: `path :value`
60#[derive(Debug)]
61pub struct DeleteAction {
62    pub path: Vec<PathElement>,
63    pub value_ref: String,
64}
65
66/// Parse an UpdateExpression string.
67pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
68    let tokens = match tokenize(expr) {
69        Ok(t) => t,
70        Err(err) => {
71            // Tokenizer-level syntax error (e.g. stray `!` mid-expression):
72            // emit the same shape as parser-level errors, with a tokenizer-style
73            // near: window (offending byte plus at most one more non-whitespace byte).
74            let bad = &expr[err.position..err.position + err.bad_len];
75            let near = near_window_tokenizer(expr, err.position);
76            return Err(format!(
77                r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
78            ));
79        }
80    };
81    let mut stream = TokenStream::new(tokens);
82
83    let mut set_actions = Vec::new();
84    let mut remove_actions = Vec::new();
85    let mut add_actions = Vec::new();
86    let mut delete_actions = Vec::new();
87
88    let mut seen_set = false;
89    let mut seen_remove = false;
90    let mut seen_add = false;
91    let mut seen_delete = false;
92
93    while !stream.at_end() {
94        match stream.peek() {
95            Some(Token::Set) => {
96                if seen_set {
97                    return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
98                }
99                seen_set = true;
100                stream.next();
101                parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
102            }
103            Some(Token::Remove) => {
104                if seen_remove {
105                    return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
106                }
107                seen_remove = true;
108                stream.next();
109                parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
110            }
111            Some(Token::Add) => {
112                if seen_add {
113                    return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
114                }
115                seen_add = true;
116                stream.next();
117                parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
118            }
119            Some(Token::Delete) => {
120                if seen_delete {
121                    return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
122                }
123                seen_delete = true;
124                stream.next();
125                parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
126            }
127            Some(_) => {
128                // Unexpected leading token where SET/REMOVE/ADD/DELETE was required.
129                // Build the AWS-style "token: \"X\", near: \"X Y\"" window from the
130                // offending token's span and the next token's span (if any).
131                let offending_span = stream
132                    .peek_span()
133                    .expect("peek_span must yield when peek did");
134                let bad = &expr[offending_span.start..offending_span.end()];
135                stream.next();
136                let next_span = stream.peek_span();
137                let near = near_window_parser(expr, offending_span, next_span);
138                return Err(format!(
139                    r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
140                ));
141            }
142            None => break,
143        }
144    }
145
146    Ok(UpdateExpr {
147        set_actions,
148        remove_actions,
149        add_actions,
150        delete_actions,
151    })
152}
153
/// Normalise a sub-parser error into an `Invalid UpdateExpression: ...` message.
///
/// - Messages that already carry the `Invalid UpdateExpression:` prefix pass through.
/// - Reserved-keyword errors get only the prefix.
/// - Everything else is treated as a syntax error: `... Syntax error; <err>`.
fn wrap_syntax_error(err: String) -> String {
    const PREFIX: &str = "Invalid UpdateExpression:";
    match err {
        wrapped if wrapped.starts_with(PREFIX) => wrapped,
        keyword if keyword.starts_with("Attribute name is a reserved keyword") => {
            format!("{PREFIX} {keyword}")
        }
        other => format!("{PREFIX} Syntax error; {other}"),
    }
}
165
166/// Walk an UpdateExpr and track all attribute name and value references
167/// without actually evaluating or modifying any item. This is used for
168/// pre-validation: checking that all referenced names/values are defined,
169/// and detecting unused names/values.
170pub fn track_references(
171    expr: &UpdateExpr,
172    tracker: &TrackedExpressionAttributes,
173) -> Result<(), String> {
174    // Collect all target paths for overlap/conflict detection
175    let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
176
177    for action in &expr.set_actions {
178        track_path_refs(&action.path, tracker)?;
179        track_set_value_refs(&action.value, tracker)?;
180        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
181    }
182    for path in &expr.remove_actions {
183        track_path_refs(path, tracker)?;
184        all_target_paths.push(resolve_tracked_path(path, tracker));
185    }
186    for action in &expr.add_actions {
187        track_path_refs(&action.path, tracker)?;
188        let val = tracker.resolve_value(&action.value_ref)?;
189        // Validate ADD operand type statically
190        validate_add_type(val)?;
191        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
192    }
193    for action in &expr.delete_actions {
194        track_path_refs(&action.path, tracker)?;
195        let val = tracker.resolve_value(&action.value_ref)?;
196        // Validate DELETE operand type statically
197        validate_delete_type(val)?;
198        all_target_paths.push(resolve_tracked_path(&action.path, tracker));
199    }
200
201    // Static type validation for SET value expressions
202    for action in &expr.set_actions {
203        validate_set_value_types(&action.value, tracker)?;
204    }
205
206    // Check for overlapping/conflicting paths
207    check_path_overlaps(&all_target_paths)?;
208
209    Ok(())
210}
211
212/// Validate that an ADD operand has a compatible type.
213fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
214    use crate::types::AttributeValue;
215    match val {
216        AttributeValue::N(_)
217        | AttributeValue::SS(_)
218        | AttributeValue::NS(_)
219        | AttributeValue::BS(_) => Ok(()),
220        _ => Err(format!(
221            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
222             operator: ADD, operand type: {}",
223            dynamo_type_name(val)
224        )),
225    }
226}
227
228/// Validate that a DELETE operand has a compatible type.
229fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
230    use crate::types::AttributeValue;
231    match val {
232        AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
233        _ => Err(format!(
234            "Invalid UpdateExpression: Incorrect operand type for operator or function; \
235             operator: DELETE, operand type: {}",
236            dynamo_type_name(val)
237        )),
238    }
239}
240
241/// Map an AttributeValue to its DynamoDB type name for error messages.
242fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
243    use crate::types::AttributeValue;
244    match val {
245        AttributeValue::S(_) => "STRING",
246        AttributeValue::N(_) => "NUMBER",
247        AttributeValue::B(_) => "BINARY",
248        AttributeValue::BOOL(_) => "BOOLEAN",
249        AttributeValue::NULL(_) => "NULL",
250        AttributeValue::SS(_) => "SS",
251        AttributeValue::NS(_) => "NS",
252        AttributeValue::BS(_) => "BS",
253        AttributeValue::L(_) => "LIST",
254        AttributeValue::M(_) => "MAP",
255    }
256}
257
258/// Validate types for SET value expressions (arithmetic, list_append).
259fn validate_set_value_types(
260    value: &SetValue,
261    tracker: &TrackedExpressionAttributes,
262) -> Result<(), String> {
263    match value {
264        SetValue::Operand(op) => validate_set_operand_types(op, tracker),
265        SetValue::Plus(left, right) => {
266            validate_arithmetic_operand(left, "+", tracker)?;
267            validate_arithmetic_operand(right, "+", tracker)
268        }
269        SetValue::Minus(left, right) => {
270            validate_arithmetic_operand(left, "-", tracker)?;
271            validate_arithmetic_operand(right, "-", tracker)
272        }
273    }
274}
275
276/// Validate that an operand used in + or - is a number (if it's a value ref).
277fn validate_arithmetic_operand(
278    operand: &SetOperand,
279    op: &str,
280    tracker: &TrackedExpressionAttributes,
281) -> Result<(), String> {
282    use crate::types::AttributeValue;
283    match operand {
284        SetOperand::ValueRef(name) => {
285            let val = tracker.resolve_value(name)?;
286            if !matches!(val, AttributeValue::N(_)) {
287                return Err(format!(
288                    "Invalid UpdateExpression: Incorrect operand type for operator or function; \
289                     operator or function: {op}, operand type: {}",
290                    dynamo_type_name(val)
291                ));
292            }
293            Ok(())
294        }
295        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
296        SetOperand::ListAppend(a, b) => {
297            validate_list_append_operand(a, tracker)?;
298            validate_list_append_operand(b, tracker)
299        }
300        SetOperand::Path(_) => Ok(()), // Path types checked at runtime
301    }
302}
303
304/// Validate types for a set operand (recursively).
305fn validate_set_operand_types(
306    operand: &SetOperand,
307    tracker: &TrackedExpressionAttributes,
308) -> Result<(), String> {
309    match operand {
310        SetOperand::ListAppend(a, b) => {
311            validate_list_append_operand(a, tracker)?;
312            validate_list_append_operand(b, tracker)
313        }
314        SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
315        _ => Ok(()),
316    }
317}
318
319/// Validate a list_append operand is a list if it's a value ref.
320fn validate_list_append_operand(
321    operand: &SetOperand,
322    tracker: &TrackedExpressionAttributes,
323) -> Result<(), String> {
324    use crate::types::AttributeValue;
325    if let SetOperand::ValueRef(name) = operand {
326        let val = tracker.resolve_value(name)?;
327        if !matches!(val, AttributeValue::L(_)) {
328            return Err(format!(
329                "Invalid UpdateExpression: Incorrect operand type for operator or function; \
330                 operator or function: list_append, operand type: {}",
331                dynamo_type_name(val)
332            ));
333        }
334    }
335    Ok(())
336}
337
338/// Resolve path elements to their final names (expanding #name refs).
339fn resolve_tracked_path(
340    path: &[PathElement],
341    tracker: &TrackedExpressionAttributes,
342) -> Vec<PathElement> {
343    path.iter()
344        .map(|elem| {
345            if let PathElement::Attribute(name) = elem {
346                if name.starts_with('#') {
347                    if let Ok(resolved) = tracker.resolve_name(name) {
348                        return PathElement::Attribute(resolved);
349                    }
350                }
351            }
352            elem.clone()
353        })
354        .collect()
355}
356
357/// Format a path for error messages in dynalite format: [a, b, [1], c].
358fn format_path_for_error(path: &[PathElement]) -> String {
359    let parts: Vec<String> = path
360        .iter()
361        .map(|elem| match elem {
362            PathElement::Attribute(name) => name.clone(),
363            PathElement::Index(i) => format!("[{i}]"),
364        })
365        .collect();
366    format!("[{}]", parts.join(", "))
367}
368
369/// Check for overlapping or conflicting document paths.
370///
371/// Two paths overlap if one is a prefix of the other (e.g., `a.b` and `a.b.c`).
372/// Two paths conflict if they share elements but diverge in type at the same
373/// position (e.g., `a[3].c` and `a.c[3]`).
374fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
375    for i in 0..paths.len() {
376        for j in (i + 1)..paths.len() {
377            let a = &paths[i];
378            let b = &paths[j];
379            let min_len = a.len().min(b.len());
380
381            // Check common prefix length
382            let mut common = 0;
383            for k in 0..min_len {
384                if a[k] == b[k] {
385                    common += 1;
386                } else {
387                    break;
388                }
389            }
390
391            if common == 0 {
392                continue;
393            }
394
395            // If one path is a prefix of the other, they overlap
396            if common == a.len() || common == b.len() {
397                let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
398                return Err(format!(
399                    "Invalid UpdateExpression: Two document paths overlap with each other; \
400                     must remove or rewrite one of these paths; \
401                     path one: {}, path two: {}",
402                    format_path_for_error(longer),
403                    format_path_for_error(shorter)
404                ));
405            }
406
407            // If paths share a prefix but diverge, they conflict
408            if common > 0 && common < min_len && a == b {
409                return Err(format!(
410                    "Invalid UpdateExpression: Two document paths conflict with each other; \
411                     must remove or rewrite one of these paths; \
412                     path one: {}, path two: {}",
413                    format_path_for_error(a),
414                    format_path_for_error(b)
415                ));
416            }
417        }
418    }
419    Ok(())
420}
421
422fn track_path_refs(
423    path: &[PathElement],
424    tracker: &TrackedExpressionAttributes,
425) -> Result<(), String> {
426    for elem in path {
427        if let PathElement::Attribute(name) = elem {
428            if name.starts_with('#') {
429                tracker.resolve_name(name)?;
430            }
431        }
432    }
433    Ok(())
434}
435
436fn track_set_value_refs(
437    value: &SetValue,
438    tracker: &TrackedExpressionAttributes,
439) -> Result<(), String> {
440    match value {
441        SetValue::Operand(op) => track_set_operand_refs(op, tracker),
442        SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
443            track_set_operand_refs(left, tracker)?;
444            track_set_operand_refs(right, tracker)
445        }
446    }
447}
448
449fn track_set_operand_refs(
450    operand: &SetOperand,
451    tracker: &TrackedExpressionAttributes,
452) -> Result<(), String> {
453    match operand {
454        SetOperand::Path(path) => track_path_refs(path, tracker),
455        SetOperand::ValueRef(name) => {
456            tracker.resolve_value(name)?;
457            Ok(())
458        }
459        SetOperand::IfNotExists(path, default) => {
460            track_path_refs(path, tracker)?;
461            track_set_operand_refs(default, tracker)
462        }
463        SetOperand::ListAppend(a, b) => {
464            track_set_operand_refs(a, tracker)?;
465            track_set_operand_refs(b, tracker)
466        }
467    }
468}
469
470/// Apply an update expression to an item (mutating it in place), tracking attribute usage.
471pub fn apply(
472    item: &mut HashMap<String, AttributeValue>,
473    expr: &UpdateExpr,
474    tracker: &TrackedExpressionAttributes,
475) -> Result<(), String> {
476    // Process SET actions
477    for action in &expr.set_actions {
478        let resolved_path = resolve_path_elements(&action.path, tracker)?;
479        let value = evaluate_set_value(&action.value, item, tracker)?;
480        set_path(item, &resolved_path, value)?;
481    }
482
483    // Process REMOVE actions
484    for path in &expr.remove_actions {
485        let resolved_path = resolve_path_elements(path, tracker)?;
486        remove_path(item, &resolved_path)?;
487    }
488
489    // Process ADD actions
490    for action in &expr.add_actions {
491        let resolved_path = resolve_path_elements(&action.path, tracker)?;
492        let add_val = tracker.resolve_value(&action.value_ref)?.clone();
493        apply_add(item, &resolved_path, &add_val).map_err(|_| {
494            "An operand in the update expression has an incorrect data type".to_string()
495        })?;
496    }
497
498    // Process DELETE actions
499    for action in &expr.delete_actions {
500        let resolved_path = resolve_path_elements(&action.path, tracker)?;
501        let del_val = tracker.resolve_value(&action.value_ref)?.clone();
502        apply_delete(item, &resolved_path, &del_val).map_err(|_| {
503            "An operand in the update expression has an incorrect data type".to_string()
504        })?;
505    }
506
507    Ok(())
508}
509
510// ---------------------------------------------------------------------------
511// SET value evaluation
512// ---------------------------------------------------------------------------
513
514fn evaluate_set_value(
515    value: &SetValue,
516    item: &HashMap<String, AttributeValue>,
517    tracker: &TrackedExpressionAttributes,
518) -> Result<AttributeValue, String> {
519    match value {
520        SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
521        SetValue::Plus(left, right) => {
522            let lv = evaluate_set_operand(left, item, tracker)?;
523            let rv = evaluate_set_operand(right, item, tracker)?;
524            match (&lv, &rv) {
525                (AttributeValue::N(a), AttributeValue::N(b)) => {
526                    use bigdecimal::BigDecimal;
527                    use std::str::FromStr;
528                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
529                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
530                    let result = &da + &db;
531                    Ok(AttributeValue::N(format_number(&result)))
532                }
533                _ => Err("Operands for + must be numbers".to_string()),
534            }
535        }
536        SetValue::Minus(left, right) => {
537            let lv = evaluate_set_operand(left, item, tracker)?;
538            let rv = evaluate_set_operand(right, item, tracker)?;
539            match (&lv, &rv) {
540                (AttributeValue::N(a), AttributeValue::N(b)) => {
541                    use bigdecimal::BigDecimal;
542                    use std::str::FromStr;
543                    let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
544                    let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
545                    let result = &da - &db;
546                    Ok(AttributeValue::N(format_number(&result)))
547                }
548                _ => Err("Operands for - must be numbers".to_string()),
549            }
550        }
551    }
552}
553
554fn evaluate_set_operand(
555    operand: &SetOperand,
556    item: &HashMap<String, AttributeValue>,
557    tracker: &TrackedExpressionAttributes,
558) -> Result<AttributeValue, String> {
559    match operand {
560        SetOperand::Path(path) => {
561            let resolved = resolve_path_elements(path, tracker)?;
562            resolve_path(item, &resolved).ok_or_else(|| {
563                "The provided expression refers to an attribute that does not exist in the item"
564                    .to_string()
565            })
566        }
567        SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
568        SetOperand::IfNotExists(path, default) => {
569            let resolved = resolve_path_elements(path, tracker)?;
570            match resolve_path(item, &resolved) {
571                Some(existing) => Ok(existing),
572                None => evaluate_set_operand(default, item, tracker),
573            }
574        }
575        SetOperand::ListAppend(list1, list2) => {
576            let v1 = evaluate_set_operand(list1, item, tracker)?;
577            let v2 = evaluate_set_operand(list2, item, tracker)?;
578            match (v1, v2) {
579                (AttributeValue::L(mut a), AttributeValue::L(b)) => {
580                    a.extend(b);
581                    Ok(AttributeValue::L(a))
582                }
583                _ => Err("list_append requires two list operands".to_string()),
584            }
585        }
586    }
587}
588
589// ---------------------------------------------------------------------------
590// ADD action
591// ---------------------------------------------------------------------------
592
593/// Public wrapper for use by legacy `AttributeUpdates` support.
594pub fn apply_add_public(
595    item: &mut HashMap<String, AttributeValue>,
596    path: &[PathElement],
597    add_val: &AttributeValue,
598) -> Result<(), String> {
599    apply_add(item, path, add_val)
600}
601
602fn apply_add(
603    item: &mut HashMap<String, AttributeValue>,
604    path: &[PathElement],
605    add_val: &AttributeValue,
606) -> Result<(), String> {
607    let existing = resolve_path(item, path);
608
609    match (existing, add_val) {
610        // Number: add to existing number or create
611        (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
612            use bigdecimal::BigDecimal;
613            use std::str::FromStr;
614            let de = BigDecimal::from_str(&existing_n)
615                .map_err(|_| format!("Invalid number: {existing_n}"))?;
616            let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
617            let result = &de + &da;
618            set_path(item, path, AttributeValue::N(format_number(&result)))
619        }
620        (None, AttributeValue::N(_)) => {
621            // Create with the provided value
622            set_path(item, path, add_val.clone())
623        }
624
625        // String set: union
626        (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
627            for s in add_set {
628                if !existing_set.contains(s) {
629                    existing_set.push(s.clone());
630                }
631            }
632            set_path(item, path, AttributeValue::SS(existing_set))
633        }
634        (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
635
636        // Number set: union
637        (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
638            for n in add_set {
639                if !existing_set.contains(n) {
640                    existing_set.push(n.clone());
641                }
642            }
643            set_path(item, path, AttributeValue::NS(existing_set))
644        }
645        (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
646
647        // Binary set: union
648        (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
649            for b in add_set {
650                if !existing_set.contains(b) {
651                    existing_set.push(b.clone());
652                }
653            }
654            set_path(item, path, AttributeValue::BS(existing_set))
655        }
656        (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
657
658        // List: append elements (legacy AttributeUpdates behaviour)
659        (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
660            existing_list.extend(add_list.iter().cloned());
661            set_path(item, path, AttributeValue::L(existing_list))
662        }
663        (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
664
665        _ => Err("Type mismatch for attribute to update".to_string()),
666    }
667}
668
669// ---------------------------------------------------------------------------
670// DELETE action
671// ---------------------------------------------------------------------------
672
673/// Public wrapper for use by legacy `AttributeUpdates` support.
674pub fn apply_delete_public(
675    item: &mut HashMap<String, AttributeValue>,
676    path: &[PathElement],
677    del_val: &AttributeValue,
678) -> Result<(), String> {
679    apply_delete(item, path, del_val)
680}
681
682fn apply_delete(
683    item: &mut HashMap<String, AttributeValue>,
684    path: &[PathElement],
685    del_val: &AttributeValue,
686) -> Result<(), String> {
687    let existing = resolve_path(item, path);
688
689    match (existing, del_val) {
690        (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
691            let new_set: Vec<String> = existing_set
692                .into_iter()
693                .filter(|s| !del_set.contains(s))
694                .collect();
695            if new_set.is_empty() {
696                remove_path(item, path)
697            } else {
698                set_path(item, path, AttributeValue::SS(new_set))
699            }
700        }
701        (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
702            let new_set: Vec<String> = existing_set
703                .into_iter()
704                .filter(|n| !del_set.contains(n))
705                .collect();
706            if new_set.is_empty() {
707                remove_path(item, path)
708            } else {
709                set_path(item, path, AttributeValue::NS(new_set))
710            }
711        }
712        (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
713            let new_set: Vec<Vec<u8>> = existing_set
714                .into_iter()
715                .filter(|b| !del_set.contains(b))
716                .collect();
717            if new_set.is_empty() {
718                remove_path(item, path)
719            } else {
720                set_path(item, path, AttributeValue::BS(new_set))
721            }
722        }
723        (None, _) => Ok(()), // Nothing to delete from
724        _ => Err("Type mismatch for attribute to update".to_string()),
725    }
726}
727
728// ---------------------------------------------------------------------------
729// Parser
730// ---------------------------------------------------------------------------
731
732fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
733    actions.push(parse_set_action(stream)?);
734    while matches!(stream.peek(), Some(Token::Comma)) {
735        stream.next();
736        actions.push(parse_set_action(stream)?);
737    }
738    Ok(())
739}
740
741fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
742    let path = parse_raw_path(stream)?;
743    stream.expect(&Token::Eq)?;
744    let value = parse_set_value(stream)?;
745    Ok(SetAction { path, value })
746}
747
748fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
749    let left = parse_set_operand(stream)?;
750
751    match stream.peek() {
752        Some(Token::Plus) => {
753            stream.next();
754            let right = parse_set_operand(stream)?;
755            Ok(SetValue::Plus(left, right))
756        }
757        Some(Token::Minus) => {
758            stream.next();
759            let right = parse_set_operand(stream)?;
760            Ok(SetValue::Minus(left, right))
761        }
762        _ => Ok(SetValue::Operand(left)),
763    }
764}
765
766fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
767    // Check for functions: if_not_exists, list_append
768    if let Some(Token::Identifier(name)) = stream.peek() {
769        let func_name = name.to_lowercase();
770        let orig_name = name.clone();
771        match func_name.as_str() {
772            "if_not_exists" => {
773                stream.next();
774                stream.expect(&Token::LParen)?;
775
776                // First argument must be a document path (not a value ref or function)
777                match stream.peek() {
778                    Some(Token::ValueRef(_)) => {
779                        return Err(
780                            "Invalid UpdateExpression: Operator or function requires a document path; \
781                             operator or function: if_not_exists".to_string()
782                        );
783                    }
784                    Some(Token::Identifier(fname))
785                        if fname.to_lowercase() == "if_not_exists"
786                            || fname.to_lowercase() == "list_append" =>
787                    {
788                        return Err(
789                            "Invalid UpdateExpression: Operator or function requires a document path; \
790                             operator or function: if_not_exists".to_string()
791                        );
792                    }
793                    _ => {}
794                }
795
796                let path = parse_raw_path(stream)?;
797
798                // Check for correct number of operands
799                if !matches!(stream.peek(), Some(Token::Comma)) {
800                    return Err(
801                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
802                         operator or function: if_not_exists, number of operands: 1".to_string()
803                    );
804                }
805                stream.expect(&Token::Comma)?;
806                let default = parse_set_operand(stream)?;
807                stream.expect(&Token::RParen)?;
808                return Ok(SetOperand::IfNotExists(path, Box::new(default)));
809            }
810            "list_append" => {
811                stream.next();
812                stream.expect(&Token::LParen)?;
813                let list1 = parse_set_operand(stream)?;
814
815                // Check for correct number of operands
816                if !matches!(stream.peek(), Some(Token::Comma)) {
817                    return Err(
818                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
819                         operator or function: list_append, number of operands: 1".to_string()
820                    );
821                }
822                stream.expect(&Token::Comma)?;
823                let list2 = parse_set_operand(stream)?;
824                stream.expect(&Token::RParen)?;
825                return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
826            }
827            _ => {
828                // Check if this looks like a function call (identifier followed by '(')
829                // If so, report "Invalid function name" for unknown functions.
830                let saved_pos = stream.pos();
831                stream.next();
832                if matches!(stream.peek(), Some(Token::LParen)) {
833                    return Err(format!(
834                        "Invalid UpdateExpression: Invalid function name; function: {}",
835                        orig_name
836                    ));
837                }
838                // Rewind — not a function call, treat as path
839                stream.set_pos(saved_pos);
840            }
841        }
842    }
843
844    match stream.peek() {
845        Some(Token::ValueRef(_)) => {
846            if let Some(Token::ValueRef(name)) = stream.next().cloned() {
847                Ok(SetOperand::ValueRef(name))
848            } else {
849                unreachable!()
850            }
851        }
852        Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
853            let path = parse_raw_path(stream)?;
854            Ok(SetOperand::Path(path))
855        }
856        Some(t) => Err(format!("Expected operand in SET, got {t}")),
857        None => Err("Expected operand in SET, got end of expression".to_string()),
858    }
859}
860
861fn parse_remove_clause(
862    stream: &mut TokenStream,
863    actions: &mut Vec<Vec<PathElement>>,
864) -> Result<(), String> {
865    actions.push(parse_raw_path(stream)?);
866    while matches!(stream.peek(), Some(Token::Comma)) {
867        stream.next();
868        actions.push(parse_raw_path(stream)?);
869    }
870    Ok(())
871}
872
873fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
874    actions.push(parse_add_action(stream)?);
875    while matches!(stream.peek(), Some(Token::Comma)) {
876        stream.next();
877        actions.push(parse_add_action(stream)?);
878    }
879    Ok(())
880}
881
882fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
883    let path = parse_raw_path(stream)?;
884    match stream.next() {
885        Some(Token::ValueRef(name)) => Ok(AddAction {
886            path,
887            value_ref: name.clone(),
888        }),
889        Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
890        None => Err("Expected value reference in ADD, got end of expression".to_string()),
891    }
892}
893
894fn parse_delete_clause(
895    stream: &mut TokenStream,
896    actions: &mut Vec<DeleteAction>,
897) -> Result<(), String> {
898    actions.push(parse_delete_action(stream)?);
899    while matches!(stream.peek(), Some(Token::Comma)) {
900        stream.next();
901        actions.push(parse_delete_action(stream)?);
902    }
903    Ok(())
904}
905
906fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
907    let path = parse_raw_path(stream)?;
908    match stream.next() {
909        Some(Token::ValueRef(name)) => Ok(DeleteAction {
910            path,
911            value_ref: name.clone(),
912        }),
913        Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
914        None => Err("Expected value reference in DELETE, got end of expression".to_string()),
915    }
916}
917
918/// Format a BigDecimal number, stripping unnecessary trailing zeros.
919/// DynamoDB returns numbers without scientific notation.
920fn format_number(n: &bigdecimal::BigDecimal) -> String {
921    let normalized = n.normalized();
922    // Force scale >= 0 so BigDecimal renders without scientific notation.
923    // When the exponent is negative (large integer like 1e38), with_scale(0)
924    // expands to full decimal digits.
925    if normalized.as_bigint_and_exponent().1 < 0 {
926        normalized.with_scale(0).to_string()
927    } else {
928        normalized.to_string()
929    }
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
    }

    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // Apply again — existing value should be preserved
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));
    }

    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::L(list) = &item["entries"] {
            assert_eq!(list.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set.len(), 3); // red, blue, green (blue deduplicated)
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);
    }

    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }

    #[test]
    fn test_remove_multiple_paths() {
        let expr = parse("REMOVE attr1, attr2, attr3").unwrap();
        assert_eq!(expr.remove_actions.len(), 3);
    }

    #[test]
    fn test_add_multiple_actions() {
        let expr = parse("ADD tally :x, hits :y").unwrap();
        assert_eq!(expr.add_actions.len(), 2);
    }

    #[test]
    fn test_delete_multiple_actions() {
        let expr = parse("DELETE colors :x, shapes :y").unwrap();
        assert_eq!(expr.delete_actions.len(), 2);
    }

    #[test]
    fn test_add_and_delete_require_value_ref() {
        // A bare path, or a path followed by another path, is not a valid action.
        assert!(parse("ADD tally").is_err());
        assert!(parse("ADD tally hits").is_err());
        assert!(parse("DELETE colors").is_err());
        assert!(parse("DELETE colors shapes").is_err());
    }

    #[test]
    fn test_if_not_exists_requires_document_path() {
        assert!(parse("SET a = if_not_exists(:v, :w)").is_err());
        assert!(parse("SET a = if_not_exists(list_append(a, :v), :w)").is_err());
    }

    #[test]
    fn test_function_operand_count_errors() {
        assert!(parse("SET a = if_not_exists(a)").is_err());
        assert!(parse("SET a = list_append(a)").is_err());
        assert!(parse("SET a = if_not_exists(a, :v, :w)").is_err());
        assert!(parse("SET a = list_append(a, :v, :w)").is_err());
    }

    #[test]
    fn test_unknown_function_rejected() {
        assert!(parse("SET a = frobnicate(b)").is_err());
    }
}