1use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{
7 Token, TokenStream, near_window_parser, near_window_tokenizer, tokenize,
8};
9use crate::expressions::{
10 PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
11 set_path,
12};
13use crate::types::AttributeValue;
14use std::collections::HashMap;
15
/// Parsed form of an `UpdateExpression`, grouped by clause.
///
/// [`parse`] accepts each clause keyword at most once; within a clause the actions
/// keep the order in which they were written.
#[derive(Debug)]
pub struct UpdateExpr {
    /// Assignments from the `SET` clause.
    pub set_actions: Vec<SetAction>,
    /// Document paths listed in the `REMOVE` clause, as written (placeholders unresolved).
    pub remove_actions: Vec<Vec<PathElement>>,
    /// Actions from the `ADD` clause.
    pub add_actions: Vec<AddAction>,
    /// Actions from the `DELETE` clause.
    pub delete_actions: Vec<DeleteAction>,
}
24
/// One `path = value` assignment inside a `SET` clause.
#[derive(Debug)]
pub struct SetAction {
    /// Target document path as written; may contain `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Right-hand side to evaluate.
    pub value: SetValue,
}
31
/// Right-hand side of a `SET` assignment: a single operand, or one binary `+`/`-`.
#[derive(Debug)]
pub enum SetValue {
    /// A lone operand.
    Operand(SetOperand),
    /// `left + right`, evaluated as numeric addition.
    Plus(SetOperand, SetOperand),
    /// `left - right`, evaluated as numeric subtraction.
    Minus(SetOperand, SetOperand),
}
42
/// An operand that may appear on the right-hand side of a `SET` assignment.
#[derive(Debug)]
pub enum SetOperand {
    /// A document path whose current value is read from the item.
    Path(Vec<PathElement>),
    /// An expression attribute value placeholder (e.g. `:val`).
    ValueRef(String),
    /// `if_not_exists(path, default)`.
    IfNotExists(Vec<PathElement>, Box<SetOperand>),
    /// `list_append(first, second)`.
    ListAppend(Box<SetOperand>, Box<SetOperand>),
    /// A parenthesised sub-expression such as `(c - :v)`.
    Group(Box<SetValue>),
}
53
/// One `path :value` pair inside an `ADD` clause.
#[derive(Debug)]
pub struct AddAction {
    /// Target document path as written; may contain `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Name of the value placeholder supplying the operand.
    pub value_ref: String,
}
60
/// One `path :value` pair inside a `DELETE` clause.
#[derive(Debug)]
pub struct DeleteAction {
    /// Target document path as written; may contain `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Name of the value placeholder holding the elements to remove.
    pub value_ref: String,
}
67
68pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
70 let tokens = match tokenize(expr) {
71 Ok(t) => t,
72 Err(err) => {
73 let bad = &expr[err.position..err.position + err.bad_len];
77 let near = near_window_tokenizer(expr, err.position);
78 return Err(format!(
79 r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
80 ));
81 }
82 };
83 let mut stream = TokenStream::new(tokens);
84
85 let mut set_actions = Vec::new();
86 let mut remove_actions = Vec::new();
87 let mut add_actions = Vec::new();
88 let mut delete_actions = Vec::new();
89
90 let mut seen_set = false;
91 let mut seen_remove = false;
92 let mut seen_add = false;
93 let mut seen_delete = false;
94
95 while !stream.at_end() {
96 match stream.peek() {
97 Some(Token::Set) => {
98 if seen_set {
99 return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
100 }
101 seen_set = true;
102 stream.next();
103 parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
104 }
105 Some(Token::Remove) => {
106 if seen_remove {
107 return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
108 }
109 seen_remove = true;
110 stream.next();
111 parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
112 }
113 Some(Token::Add) => {
114 if seen_add {
115 return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
116 }
117 seen_add = true;
118 stream.next();
119 parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
120 }
121 Some(Token::Delete) => {
122 if seen_delete {
123 return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
124 }
125 seen_delete = true;
126 stream.next();
127 parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
128 }
129 Some(_) => {
130 let offending_span = stream
134 .peek_span()
135 .expect("peek_span must yield when peek did");
136 let bad = &expr[offending_span.start..offending_span.end()];
137 stream.next();
138 let next_span = stream.peek_span();
139 let near = near_window_parser(expr, offending_span, next_span);
140 return Err(format!(
141 r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
142 ));
143 }
144 None => break,
145 }
146 }
147
148 Ok(UpdateExpr {
149 set_actions,
150 remove_actions,
151 add_actions,
152 delete_actions,
153 })
154}
155
/// Normalizes a clause-level parse error into an `Invalid UpdateExpression:` message.
///
/// - Messages that already carry the prefix are returned untouched.
/// - Reserved-keyword errors get only the prefix.
/// - Everything else is reported as a generic syntax error.
fn wrap_syntax_error(err: String) -> String {
    const PREFIX: &str = "Invalid UpdateExpression:";
    if err.starts_with(PREFIX) {
        return err;
    }
    if err.starts_with("Attribute name is a reserved keyword") {
        return format!("Invalid UpdateExpression: {err}");
    }
    format!("Invalid UpdateExpression: Syntax error; {err}")
}
167
168pub fn track_references(
173 expr: &UpdateExpr,
174 tracker: &TrackedExpressionAttributes,
175) -> Result<(), String> {
176 let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
178
179 for action in &expr.set_actions {
180 track_path_refs(&action.path, tracker)?;
181 track_set_value_refs(&action.value, tracker)?;
182 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
183 }
184 for path in &expr.remove_actions {
185 track_path_refs(path, tracker)?;
186 all_target_paths.push(resolve_tracked_path(path, tracker));
187 }
188 for action in &expr.add_actions {
189 track_path_refs(&action.path, tracker)?;
190 let val = tracker.resolve_value(&action.value_ref)?;
191 validate_add_type(val)?;
193 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
194 }
195 for action in &expr.delete_actions {
196 track_path_refs(&action.path, tracker)?;
197 let val = tracker.resolve_value(&action.value_ref)?;
198 validate_delete_type(val)?;
200 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
201 }
202
203 for action in &expr.set_actions {
205 validate_set_value_types(&action.value, tracker)?;
206 }
207
208 check_path_overlaps(&all_target_paths)?;
210
211 Ok(())
212}
213
214fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
216 use crate::types::AttributeValue;
217 match val {
218 AttributeValue::N(_)
219 | AttributeValue::SS(_)
220 | AttributeValue::NS(_)
221 | AttributeValue::BS(_) => Ok(()),
222 _ => Err(format!(
223 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
224 operator: ADD, operand type: {}",
225 dynamo_type_name(val)
226 )),
227 }
228}
229
230fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
232 use crate::types::AttributeValue;
233 match val {
234 AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
235 _ => Err(format!(
236 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
237 operator: DELETE, operand type: {}",
238 dynamo_type_name(val)
239 )),
240 }
241}
242
243fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
245 use crate::types::AttributeValue;
246 match val {
247 AttributeValue::S(_) => "STRING",
248 AttributeValue::N(_) => "NUMBER",
249 AttributeValue::B(_) => "BINARY",
250 AttributeValue::BOOL(_) => "BOOLEAN",
251 AttributeValue::NULL(_) => "NULL",
252 AttributeValue::SS(_) => "SS",
253 AttributeValue::NS(_) => "NS",
254 AttributeValue::BS(_) => "BS",
255 AttributeValue::L(_) => "LIST",
256 AttributeValue::M(_) => "MAP",
257 }
258}
259
260fn validate_set_value_types(
262 value: &SetValue,
263 tracker: &TrackedExpressionAttributes,
264) -> Result<(), String> {
265 match value {
266 SetValue::Operand(op) => validate_set_operand_types(op, tracker),
267 SetValue::Plus(left, right) => {
268 validate_arithmetic_operand(left, "+", tracker)?;
269 validate_arithmetic_operand(right, "+", tracker)
270 }
271 SetValue::Minus(left, right) => {
272 validate_arithmetic_operand(left, "-", tracker)?;
273 validate_arithmetic_operand(right, "-", tracker)
274 }
275 }
276}
277
278fn validate_arithmetic_operand(
280 operand: &SetOperand,
281 op: &str,
282 tracker: &TrackedExpressionAttributes,
283) -> Result<(), String> {
284 use crate::types::AttributeValue;
285 match operand {
286 SetOperand::ValueRef(name) => {
287 let val = tracker.resolve_value(name)?;
288 if !matches!(val, AttributeValue::N(_)) {
289 return Err(format!(
290 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
291 operator or function: {op}, operand type: {}",
292 dynamo_type_name(val)
293 ));
294 }
295 Ok(())
296 }
297 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
298 SetOperand::ListAppend(a, b) => {
299 validate_list_append_operand(a, tracker)?;
300 validate_list_append_operand(b, tracker)
301 }
302 SetOperand::Path(_) => Ok(()), SetOperand::Group(inner) => validate_set_value_types(inner, tracker),
306 }
307}
308
309fn validate_set_operand_types(
311 operand: &SetOperand,
312 tracker: &TrackedExpressionAttributes,
313) -> Result<(), String> {
314 match operand {
315 SetOperand::ListAppend(a, b) => {
316 validate_list_append_operand(a, tracker)?;
317 validate_list_append_operand(b, tracker)
318 }
319 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
320 SetOperand::Group(inner) => validate_set_value_types(inner, tracker),
321 _ => Ok(()),
322 }
323}
324
325fn validate_list_append_operand(
327 operand: &SetOperand,
328 tracker: &TrackedExpressionAttributes,
329) -> Result<(), String> {
330 use crate::types::AttributeValue;
331 if let SetOperand::ValueRef(name) = operand {
332 let val = tracker.resolve_value(name)?;
333 if !matches!(val, AttributeValue::L(_)) {
334 return Err(format!(
335 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
336 operator or function: list_append, operand type: {}",
337 dynamo_type_name(val)
338 ));
339 }
340 }
341 Ok(())
342}
343
344fn resolve_tracked_path(
346 path: &[PathElement],
347 tracker: &TrackedExpressionAttributes,
348) -> Vec<PathElement> {
349 path.iter()
350 .map(|elem| {
351 if let PathElement::Attribute(name) = elem {
352 if name.starts_with('#') {
353 if let Ok(resolved) = tracker.resolve_name(name) {
354 return PathElement::Attribute(resolved);
355 }
356 }
357 }
358 elem.clone()
359 })
360 .collect()
361}
362
363fn format_path_for_error(path: &[PathElement]) -> String {
365 let parts: Vec<String> = path
366 .iter()
367 .map(|elem| match elem {
368 PathElement::Attribute(name) => name.clone(),
369 PathElement::Index(i) => format!("[{i}]"),
370 })
371 .collect();
372 format!("[{}]", parts.join(", "))
373}
374
375fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
381 for i in 0..paths.len() {
382 for j in (i + 1)..paths.len() {
383 let a = &paths[i];
384 let b = &paths[j];
385 let min_len = a.len().min(b.len());
386
387 let mut common = 0;
389 for k in 0..min_len {
390 if a[k] == b[k] {
391 common += 1;
392 } else {
393 break;
394 }
395 }
396
397 if common == 0 {
398 continue;
399 }
400
401 if common == a.len() || common == b.len() {
403 let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
404 return Err(format!(
405 "Invalid UpdateExpression: Two document paths overlap with each other; \
406 must remove or rewrite one of these paths; \
407 path one: {}, path two: {}",
408 format_path_for_error(longer),
409 format_path_for_error(shorter)
410 ));
411 }
412
413 if common > 0 && common < min_len && a == b {
415 return Err(format!(
416 "Invalid UpdateExpression: Two document paths conflict with each other; \
417 must remove or rewrite one of these paths; \
418 path one: {}, path two: {}",
419 format_path_for_error(a),
420 format_path_for_error(b)
421 ));
422 }
423 }
424 }
425 Ok(())
426}
427
428fn track_path_refs(
429 path: &[PathElement],
430 tracker: &TrackedExpressionAttributes,
431) -> Result<(), String> {
432 for elem in path {
433 if let PathElement::Attribute(name) = elem {
434 if name.starts_with('#') {
435 tracker.resolve_name(name)?;
436 }
437 }
438 }
439 Ok(())
440}
441
442fn track_set_value_refs(
443 value: &SetValue,
444 tracker: &TrackedExpressionAttributes,
445) -> Result<(), String> {
446 match value {
447 SetValue::Operand(op) => track_set_operand_refs(op, tracker),
448 SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
449 track_set_operand_refs(left, tracker)?;
450 track_set_operand_refs(right, tracker)
451 }
452 }
453}
454
455fn track_set_operand_refs(
456 operand: &SetOperand,
457 tracker: &TrackedExpressionAttributes,
458) -> Result<(), String> {
459 match operand {
460 SetOperand::Path(path) => track_path_refs(path, tracker),
461 SetOperand::ValueRef(name) => {
462 tracker.resolve_value(name)?;
463 Ok(())
464 }
465 SetOperand::IfNotExists(path, default) => {
466 track_path_refs(path, tracker)?;
467 track_set_operand_refs(default, tracker)
468 }
469 SetOperand::ListAppend(a, b) => {
470 track_set_operand_refs(a, tracker)?;
471 track_set_operand_refs(b, tracker)
472 }
473 SetOperand::Group(inner) => track_set_value_refs(inner, tracker),
474 }
475}
476
477pub fn apply(
479 item: &mut HashMap<String, AttributeValue>,
480 expr: &UpdateExpr,
481 tracker: &TrackedExpressionAttributes,
482) -> Result<(), String> {
483 let snapshot = item.clone();
492 for action in &expr.set_actions {
493 let resolved_path = resolve_path_elements(&action.path, tracker)?;
494 let value = evaluate_set_value(&action.value, &snapshot, tracker)?;
495 set_path(item, &resolved_path, value)?;
496 }
497
498 for path in &expr.remove_actions {
500 let resolved_path = resolve_path_elements(path, tracker)?;
501 remove_path(item, &resolved_path)?;
502 }
503
504 for action in &expr.add_actions {
506 let resolved_path = resolve_path_elements(&action.path, tracker)?;
507 let add_val = tracker.resolve_value(&action.value_ref)?.clone();
508 apply_add(item, &resolved_path, &add_val).map_err(|_| {
509 "An operand in the update expression has an incorrect data type".to_string()
510 })?;
511 }
512
513 for action in &expr.delete_actions {
515 let resolved_path = resolve_path_elements(&action.path, tracker)?;
516 let del_val = tracker.resolve_value(&action.value_ref)?.clone();
517 apply_delete(item, &resolved_path, &del_val).map_err(|_| {
518 "An operand in the update expression has an incorrect data type".to_string()
519 })?;
520 }
521
522 Ok(())
523}
524
525fn evaluate_set_value(
530 value: &SetValue,
531 item: &HashMap<String, AttributeValue>,
532 tracker: &TrackedExpressionAttributes,
533) -> Result<AttributeValue, String> {
534 match value {
535 SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
536 SetValue::Plus(left, right) => {
537 let lv = evaluate_set_operand(left, item, tracker)?;
538 let rv = evaluate_set_operand(right, item, tracker)?;
539 match (&lv, &rv) {
540 (AttributeValue::N(a), AttributeValue::N(b)) => {
541 use bigdecimal::BigDecimal;
542 use std::str::FromStr;
543 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
544 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
545 let result = &da + &db;
546 Ok(AttributeValue::N(format_number(&result)))
547 }
548 _ => Err("Operands for + must be numbers".to_string()),
549 }
550 }
551 SetValue::Minus(left, right) => {
552 let lv = evaluate_set_operand(left, item, tracker)?;
553 let rv = evaluate_set_operand(right, item, tracker)?;
554 match (&lv, &rv) {
555 (AttributeValue::N(a), AttributeValue::N(b)) => {
556 use bigdecimal::BigDecimal;
557 use std::str::FromStr;
558 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
559 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
560 let result = &da - &db;
561 Ok(AttributeValue::N(format_number(&result)))
562 }
563 _ => Err("Operands for - must be numbers".to_string()),
564 }
565 }
566 }
567}
568
569fn evaluate_set_operand(
570 operand: &SetOperand,
571 item: &HashMap<String, AttributeValue>,
572 tracker: &TrackedExpressionAttributes,
573) -> Result<AttributeValue, String> {
574 match operand {
575 SetOperand::Path(path) => {
576 let resolved = resolve_path_elements(path, tracker)?;
577 resolve_path(item, &resolved).ok_or_else(|| {
578 "The provided expression refers to an attribute that does not exist in the item"
579 .to_string()
580 })
581 }
582 SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
583 SetOperand::IfNotExists(path, default) => {
584 let resolved = resolve_path_elements(path, tracker)?;
585 match resolve_path(item, &resolved) {
586 Some(existing) => Ok(existing),
587 None => evaluate_set_operand(default, item, tracker),
588 }
589 }
590 SetOperand::ListAppend(list1, list2) => {
591 let v1 = evaluate_set_operand(list1, item, tracker)?;
592 let v2 = evaluate_set_operand(list2, item, tracker)?;
593 match (v1, v2) {
594 (AttributeValue::L(mut a), AttributeValue::L(b)) => {
595 a.extend(b);
596 Ok(AttributeValue::L(a))
597 }
598 _ => Err("list_append requires two list operands".to_string()),
599 }
600 }
601 SetOperand::Group(inner) => evaluate_set_value(inner, item, tracker),
602 }
603}
604
/// Public wrapper around this module's `ADD` merge logic (`apply_add`).
///
/// `path` is used as given: no `#name` placeholder substitution happens here, so pass
/// an already-resolved path. Unlike [`apply`], errors are returned unchanged rather
/// than being replaced with a generic "incorrect data type" message.
pub fn apply_add_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    add_val: &AttributeValue,
) -> Result<(), String> {
    apply_add(item, path, add_val)
}
617
618fn apply_add(
619 item: &mut HashMap<String, AttributeValue>,
620 path: &[PathElement],
621 add_val: &AttributeValue,
622) -> Result<(), String> {
623 let existing = resolve_path(item, path);
624
625 match (existing, add_val) {
626 (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
628 use bigdecimal::BigDecimal;
629 use std::str::FromStr;
630 let de = BigDecimal::from_str(&existing_n)
631 .map_err(|_| format!("Invalid number: {existing_n}"))?;
632 let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
633 let result = &de + &da;
634 set_path(item, path, AttributeValue::N(format_number(&result)))
635 }
636 (None, AttributeValue::N(_)) => {
637 set_path(item, path, add_val.clone())
639 }
640
641 (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
643 for s in add_set {
644 if !existing_set.contains(s) {
645 existing_set.push(s.clone());
646 }
647 }
648 set_path(item, path, AttributeValue::SS(existing_set))
649 }
650 (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
651
652 (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
654 for n in add_set {
655 if !existing_set.contains(n) {
656 existing_set.push(n.clone());
657 }
658 }
659 set_path(item, path, AttributeValue::NS(existing_set))
660 }
661 (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
662
663 (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
665 for b in add_set {
666 if !existing_set.contains(b) {
667 existing_set.push(b.clone());
668 }
669 }
670 set_path(item, path, AttributeValue::BS(existing_set))
671 }
672 (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
673
674 (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
676 existing_list.extend(add_list.iter().cloned());
677 set_path(item, path, AttributeValue::L(existing_list))
678 }
679 (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
680
681 _ => Err("Type mismatch for attribute to update".to_string()),
682 }
683}
684
/// Public wrapper around this module's `DELETE` set-difference logic (`apply_delete`).
///
/// `path` is used as given: no `#name` placeholder substitution happens here, so pass
/// an already-resolved path. Unlike [`apply`], errors are returned unchanged rather
/// than being replaced with a generic "incorrect data type" message.
pub fn apply_delete_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    del_val: &AttributeValue,
) -> Result<(), String> {
    apply_delete(item, path, del_val)
}
697
698fn apply_delete(
699 item: &mut HashMap<String, AttributeValue>,
700 path: &[PathElement],
701 del_val: &AttributeValue,
702) -> Result<(), String> {
703 let existing = resolve_path(item, path);
704
705 match (existing, del_val) {
706 (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
707 let new_set: Vec<String> = existing_set
708 .into_iter()
709 .filter(|s| !del_set.contains(s))
710 .collect();
711 if new_set.is_empty() {
712 remove_path(item, path)
713 } else {
714 set_path(item, path, AttributeValue::SS(new_set))
715 }
716 }
717 (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
718 let new_set: Vec<String> = existing_set
719 .into_iter()
720 .filter(|n| !del_set.contains(n))
721 .collect();
722 if new_set.is_empty() {
723 remove_path(item, path)
724 } else {
725 set_path(item, path, AttributeValue::NS(new_set))
726 }
727 }
728 (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
729 let new_set: Vec<Vec<u8>> = existing_set
730 .into_iter()
731 .filter(|b| !del_set.contains(b))
732 .collect();
733 if new_set.is_empty() {
734 remove_path(item, path)
735 } else {
736 set_path(item, path, AttributeValue::BS(new_set))
737 }
738 }
739 (None, _) => Ok(()), _ => Err("Type mismatch for attribute to update".to_string()),
741 }
742}
743
744fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
749 actions.push(parse_set_action(stream)?);
750 while matches!(stream.peek(), Some(Token::Comma)) {
751 stream.next();
752 actions.push(parse_set_action(stream)?);
753 }
754 Ok(())
755}
756
757fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
758 let path = parse_raw_path(stream)?;
759 stream.expect(&Token::Eq)?;
760 let value = parse_set_value(stream)?;
761 Ok(SetAction { path, value })
762}
763
764fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
765 let left = parse_set_operand(stream)?;
766
767 match stream.peek() {
768 Some(Token::Plus) => {
769 stream.next();
770 let right = parse_set_operand(stream)?;
771 Ok(SetValue::Plus(left, right))
772 }
773 Some(Token::Minus) => {
774 stream.next();
775 let right = parse_set_operand(stream)?;
776 Ok(SetValue::Minus(left, right))
777 }
778 _ => Ok(SetValue::Operand(left)),
779 }
780}
781
/// Parses one SET operand.
///
/// An operand is one of:
/// - `if_not_exists(path, operand)`;
/// - `list_append(operand, operand)`;
/// - a parenthesised value;
/// - a value placeholder;
/// - a document path.
///
/// An identifier immediately followed by `(` that is not one of the two supported
/// functions is rejected as an invalid function name.
fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
    if let Some(Token::Identifier(name)) = stream.peek() {
        // NOTE(review): function names are matched case-insensitively here; confirm
        // this matches the intended (DynamoDB) behaviour.
        let func_name = name.to_lowercase();
        // Original spelling, echoed back in the "Invalid function name" error.
        let orig_name = name.clone();
        match func_name.as_str() {
            "if_not_exists" => {
                stream.next();
                stream.expect(&Token::LParen)?;

                // The first argument must be a document path: reject a value
                // placeholder or a nested function call before parsing it as one.
                match stream.peek() {
                    Some(Token::ValueRef(_)) => {
                        return Err(
                            "Invalid UpdateExpression: Operator or function requires a document path; \
                             operator or function: if_not_exists".to_string()
                        );
                    }
                    Some(Token::Identifier(fname))
                        if fname.to_lowercase() == "if_not_exists"
                            || fname.to_lowercase() == "list_append" =>
                    {
                        return Err(
                            "Invalid UpdateExpression: Operator or function requires a document path; \
                             operator or function: if_not_exists".to_string()
                        );
                    }
                    _ => {}
                }

                let path = parse_raw_path(stream)?;

                // A missing comma means only the path was supplied.
                if !matches!(stream.peek(), Some(Token::Comma)) {
                    return Err(
                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
                         operator or function: if_not_exists, number of operands: 1".to_string()
                    );
                }
                stream.expect(&Token::Comma)?;
                let default = parse_set_operand(stream)?;
                stream.expect(&Token::RParen)?;
                return Ok(SetOperand::IfNotExists(path, Box::new(default)));
            }
            "list_append" => {
                stream.next();
                stream.expect(&Token::LParen)?;
                let list1 = parse_set_operand(stream)?;

                // A missing comma means only one operand was supplied.
                if !matches!(stream.peek(), Some(Token::Comma)) {
                    return Err(
                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
                         operator or function: list_append, number of operands: 1".to_string()
                    );
                }
                stream.expect(&Token::Comma)?;
                let list2 = parse_set_operand(stream)?;
                stream.expect(&Token::RParen)?;
                return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
            }
            _ => {
                // Look one token ahead: an identifier followed by `(` is a call to an
                // unsupported function. Otherwise rewind so the identifier is parsed
                // as a document path below.
                let saved_pos = stream.pos();
                stream.next();
                if matches!(stream.peek(), Some(Token::LParen)) {
                    return Err(format!(
                        "Invalid UpdateExpression: Invalid function name; function: {}",
                        orig_name
                    ));
                }
                stream.set_pos(saved_pos);
            }
        }
    }

    match stream.peek() {
        Some(Token::LParen) => {
            // Parenthesised sub-expression, e.g. `(c - :v)`.
            stream.next();
            let inner = parse_set_value(stream)?;
            stream.expect(&Token::RParen)?;
            Ok(SetOperand::Group(Box::new(inner)))
        }
        Some(Token::ValueRef(_)) => {
            // The peek above guarantees the next token is a ValueRef.
            if let Some(Token::ValueRef(name)) = stream.next().cloned() {
                Ok(SetOperand::ValueRef(name))
            } else {
                unreachable!()
            }
        }
        Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
            let path = parse_raw_path(stream)?;
            Ok(SetOperand::Path(path))
        }
        Some(t) => Err(format!("Expected operand in SET, got {t}")),
        None => Err("Expected operand in SET, got end of expression".to_string()),
    }
}
884
885fn parse_remove_clause(
886 stream: &mut TokenStream,
887 actions: &mut Vec<Vec<PathElement>>,
888) -> Result<(), String> {
889 actions.push(parse_raw_path(stream)?);
890 while matches!(stream.peek(), Some(Token::Comma)) {
891 stream.next();
892 actions.push(parse_raw_path(stream)?);
893 }
894 Ok(())
895}
896
897fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
898 actions.push(parse_add_action(stream)?);
899 while matches!(stream.peek(), Some(Token::Comma)) {
900 stream.next();
901 actions.push(parse_add_action(stream)?);
902 }
903 Ok(())
904}
905
906fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
907 let path = parse_raw_path(stream)?;
908 match stream.next() {
909 Some(Token::ValueRef(name)) => Ok(AddAction {
910 path,
911 value_ref: name.clone(),
912 }),
913 Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
914 None => Err("Expected value reference in ADD, got end of expression".to_string()),
915 }
916}
917
918fn parse_delete_clause(
919 stream: &mut TokenStream,
920 actions: &mut Vec<DeleteAction>,
921) -> Result<(), String> {
922 actions.push(parse_delete_action(stream)?);
923 while matches!(stream.peek(), Some(Token::Comma)) {
924 stream.next();
925 actions.push(parse_delete_action(stream)?);
926 }
927 Ok(())
928}
929
930fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
931 let path = parse_raw_path(stream)?;
932 match stream.next() {
933 Some(Token::ValueRef(name)) => Ok(DeleteAction {
934 path,
935 value_ref: name.clone(),
936 }),
937 Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
938 None => Err("Expected value reference in DELETE, got end of expression".to_string()),
939 }
940}
941
942fn format_number(n: &bigdecimal::BigDecimal) -> String {
945 let normalized = n.normalized();
946 if normalized.as_bigint_and_exponent().1 < 0 {
950 normalized.with_scale(0).to_string()
951 } else {
952 normalized.to_string()
953 }
954}
955
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an item map from `(name, value)` pairs.
    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// Builds an expression-attribute-values map wrapped in `Some`.
    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
    }

    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // Once the attribute exists, if_not_exists must keep the stored value. Use a
        // value distinct from the default so the two outcomes are distinguishable.
        item.insert("hits".to_string(), AttributeValue::N("7".into()));
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("7".into()));
    }

    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::L(list) = &item["entries"] {
            assert_eq!(list.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);
    }

    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }

    #[test]
    fn test_duplicate_remove_clause_error() {
        let err = parse("REMOVE a REMOVE b").unwrap_err();
        assert!(
            err.contains("\"REMOVE\" section can only be used once"),
            "{err}"
        );
    }

    #[test]
    fn test_wrap_syntax_error_prefixes() {
        assert_eq!(
            wrap_syntax_error("Invalid UpdateExpression: x".to_string()),
            "Invalid UpdateExpression: x"
        );
        assert_eq!(
            wrap_syntax_error("oops".to_string()),
            "Invalid UpdateExpression: Syntax error; oops"
        );
    }

    #[test]
    fn test_add_rejects_string_operand() {
        let expr = parse("ADD tally :s").unwrap();
        let av = vals(&[(":s", AttributeValue::S("x".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        let err = track_references(&expr, &tracker).unwrap_err();
        assert!(err.contains("operator: ADD, operand type: STRING"), "{err}");
    }

    #[test]
    fn test_overlapping_paths_rejected() {
        let expr = parse("SET a = :v REMOVE a").unwrap();
        let av = vals(&[(":v", AttributeValue::S("x".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        let err = track_references(&expr, &tracker).unwrap_err();
        assert!(err.contains("Two document paths overlap"), "{err}");
    }

    /// SET right-hand sides must read the item as it was before the update.
    #[test]
    fn test_set_reads_pre_update_snapshot() {
        let expr = parse("SET a = :v, b = a").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("a", AttributeValue::S("OLD".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::S("NEW".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["a"], AttributeValue::S("NEW".into()));
        assert_eq!(item["b"], AttributeValue::S("OLD".into()));
    }

    /// Parenthesised arithmetic parses as a Group and evaluates like the bare form.
    #[test]
    fn test_set_parenthesised_arithmetic() {
        let expr = parse("SET c = (c - :v)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("c", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::N("3".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["c"], AttributeValue::N("7".into()));
    }

    /// Arithmetic beyond 64-bit range must stay exact.
    #[test]
    fn test_set_parenthesised_arithmetic_bigdecimal() {
        let expr = parse("SET c = (c + :v)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("c", AttributeValue::N("100000000000000000000".into())),
        ]);
        let av = vals(&[(":v", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["c"], AttributeValue::N("100000000000000000001".into()));
    }
}