1use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{Token, TokenStream, tokenize};
7use crate::expressions::{
8 PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
9 set_path,
10};
11use crate::types::AttributeValue;
12use std::collections::HashMap;
13
/// Parsed form of an `UpdateExpression`: the actions of each clause, kept in
/// the order they were written within that clause.
///
/// Paths are stored exactly as written; `#name` placeholders inside them are
/// resolved later through a `TrackedExpressionAttributes`.
#[derive(Debug)]
pub struct UpdateExpr {
    /// Actions from the `SET` clause.
    pub set_actions: Vec<SetAction>,
    /// Target document paths from the `REMOVE` clause.
    pub remove_actions: Vec<Vec<PathElement>>,
    /// Actions from the `ADD` clause.
    pub add_actions: Vec<AddAction>,
    /// Actions from the `DELETE` clause.
    pub delete_actions: Vec<DeleteAction>,
}
22
/// One `path = value` assignment from the `SET` clause.
#[derive(Debug)]
pub struct SetAction {
    /// Target document path (may contain `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Expression whose result is written to `path`.
    pub value: SetValue,
}
29
/// Right-hand side of a `SET` action.
#[derive(Debug)]
pub enum SetValue {
    /// A single operand, e.g. `:val` or `if_not_exists(a, :v)`.
    Operand(SetOperand),
    /// `left + right`; both operands must evaluate to numbers.
    Plus(SetOperand, SetOperand),
    /// `left - right`; both operands must evaluate to numbers.
    Minus(SetOperand, SetOperand),
}
40
/// A single operand on the right-hand side of a `SET` action.
#[derive(Debug)]
pub enum SetOperand {
    /// A document path whose current value is read from the item.
    Path(Vec<PathElement>),
    /// An expression attribute value placeholder such as `:val`.
    ValueRef(String),
    /// `if_not_exists(path, default)`: the value at `path` when it exists,
    /// otherwise the value of `default`.
    IfNotExists(Vec<PathElement>, Box<SetOperand>),
    /// `list_append(a, b)`: the concatenation of two lists.
    ListAppend(Box<SetOperand>, Box<SetOperand>),
}
49
/// One `path :value` pair from the `ADD` clause.
#[derive(Debug)]
pub struct AddAction {
    /// Target document path (may contain `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Placeholder (e.g. `:inc`) naming the value to add; it must resolve to
    /// a number or a set.
    pub value_ref: String,
}
56
/// One `path :value` pair from the `DELETE` clause.
#[derive(Debug)]
pub struct DeleteAction {
    /// Target document path (may contain `#name` placeholders).
    pub path: Vec<PathElement>,
    /// Placeholder naming the set of members to remove; it must resolve to a
    /// set (SS, NS or BS).
    pub value_ref: String,
}
63
64pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
66 let tokens =
67 tokenize(expr).map_err(|e| format!("Invalid UpdateExpression: Syntax error; {e}"))?;
68 let mut stream = TokenStream::new(tokens);
69
70 let mut set_actions = Vec::new();
71 let mut remove_actions = Vec::new();
72 let mut add_actions = Vec::new();
73 let mut delete_actions = Vec::new();
74
75 let mut seen_set = false;
76 let mut seen_remove = false;
77 let mut seen_add = false;
78 let mut seen_delete = false;
79
80 while !stream.at_end() {
81 match stream.peek() {
82 Some(Token::Set) => {
83 if seen_set {
84 return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
85 }
86 seen_set = true;
87 stream.next();
88 parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
89 }
90 Some(Token::Remove) => {
91 if seen_remove {
92 return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
93 }
94 seen_remove = true;
95 stream.next();
96 parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
97 }
98 Some(Token::Add) => {
99 if seen_add {
100 return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
101 }
102 seen_add = true;
103 stream.next();
104 parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
105 }
106 Some(Token::Delete) => {
107 if seen_delete {
108 return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
109 }
110 seen_delete = true;
111 stream.next();
112 parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
113 }
114 Some(t) => {
115 return Err(format!(
116 "Invalid UpdateExpression: Syntax error; token: {t}"
117 ));
118 }
119 None => break,
120 }
121 }
122
123 Ok(UpdateExpr {
124 set_actions,
125 remove_actions,
126 add_actions,
127 delete_actions,
128 })
129}
130
/// Normalises an error produced by a clause parser into an
/// `Invalid UpdateExpression` message.
///
/// - Messages that already carry that prefix are returned unchanged.
/// - Reserved-keyword errors only gain the prefix.
/// - Anything else is reported as a generic syntax error.
fn wrap_syntax_error(err: String) -> String {
    if err.starts_with("Invalid UpdateExpression:") {
        return err;
    }
    let category = if err.starts_with("Attribute name is a reserved keyword") {
        ""
    } else {
        "Syntax error; "
    };
    format!("Invalid UpdateExpression: {category}{err}")
}
142
143pub fn track_references(
148 expr: &UpdateExpr,
149 tracker: &TrackedExpressionAttributes,
150) -> Result<(), String> {
151 let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
153
154 for action in &expr.set_actions {
155 track_path_refs(&action.path, tracker)?;
156 track_set_value_refs(&action.value, tracker)?;
157 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
158 }
159 for path in &expr.remove_actions {
160 track_path_refs(path, tracker)?;
161 all_target_paths.push(resolve_tracked_path(path, tracker));
162 }
163 for action in &expr.add_actions {
164 track_path_refs(&action.path, tracker)?;
165 let val = tracker.resolve_value(&action.value_ref)?;
166 validate_add_type(val)?;
168 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
169 }
170 for action in &expr.delete_actions {
171 track_path_refs(&action.path, tracker)?;
172 let val = tracker.resolve_value(&action.value_ref)?;
173 validate_delete_type(val)?;
175 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
176 }
177
178 for action in &expr.set_actions {
180 validate_set_value_types(&action.value, tracker)?;
181 }
182
183 check_path_overlaps(&all_target_paths)?;
185
186 Ok(())
187}
188
189fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
191 use crate::types::AttributeValue;
192 match val {
193 AttributeValue::N(_)
194 | AttributeValue::SS(_)
195 | AttributeValue::NS(_)
196 | AttributeValue::BS(_) => Ok(()),
197 _ => Err(format!(
198 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
199 operator: ADD, operand type: {}",
200 dynamo_type_name(val)
201 )),
202 }
203}
204
205fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
207 use crate::types::AttributeValue;
208 match val {
209 AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
210 _ => Err(format!(
211 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
212 operator: DELETE, operand type: {}",
213 dynamo_type_name(val)
214 )),
215 }
216}
217
218fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
220 use crate::types::AttributeValue;
221 match val {
222 AttributeValue::S(_) => "STRING",
223 AttributeValue::N(_) => "NUMBER",
224 AttributeValue::B(_) => "BINARY",
225 AttributeValue::BOOL(_) => "BOOLEAN",
226 AttributeValue::NULL(_) => "NULL",
227 AttributeValue::SS(_) => "SS",
228 AttributeValue::NS(_) => "NS",
229 AttributeValue::BS(_) => "BS",
230 AttributeValue::L(_) => "LIST",
231 AttributeValue::M(_) => "MAP",
232 }
233}
234
235fn validate_set_value_types(
237 value: &SetValue,
238 tracker: &TrackedExpressionAttributes,
239) -> Result<(), String> {
240 match value {
241 SetValue::Operand(op) => validate_set_operand_types(op, tracker),
242 SetValue::Plus(left, right) => {
243 validate_arithmetic_operand(left, "+", tracker)?;
244 validate_arithmetic_operand(right, "+", tracker)
245 }
246 SetValue::Minus(left, right) => {
247 validate_arithmetic_operand(left, "-", tracker)?;
248 validate_arithmetic_operand(right, "-", tracker)
249 }
250 }
251}
252
253fn validate_arithmetic_operand(
255 operand: &SetOperand,
256 op: &str,
257 tracker: &TrackedExpressionAttributes,
258) -> Result<(), String> {
259 use crate::types::AttributeValue;
260 match operand {
261 SetOperand::ValueRef(name) => {
262 let val = tracker.resolve_value(name)?;
263 if !matches!(val, AttributeValue::N(_)) {
264 return Err(format!(
265 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
266 operator or function: {op}, operand type: {}",
267 dynamo_type_name(val)
268 ));
269 }
270 Ok(())
271 }
272 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
273 SetOperand::ListAppend(a, b) => {
274 validate_list_append_operand(a, tracker)?;
275 validate_list_append_operand(b, tracker)
276 }
277 SetOperand::Path(_) => Ok(()), }
279}
280
281fn validate_set_operand_types(
283 operand: &SetOperand,
284 tracker: &TrackedExpressionAttributes,
285) -> Result<(), String> {
286 match operand {
287 SetOperand::ListAppend(a, b) => {
288 validate_list_append_operand(a, tracker)?;
289 validate_list_append_operand(b, tracker)
290 }
291 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
292 _ => Ok(()),
293 }
294}
295
296fn validate_list_append_operand(
298 operand: &SetOperand,
299 tracker: &TrackedExpressionAttributes,
300) -> Result<(), String> {
301 use crate::types::AttributeValue;
302 if let SetOperand::ValueRef(name) = operand {
303 let val = tracker.resolve_value(name)?;
304 if !matches!(val, AttributeValue::L(_)) {
305 return Err(format!(
306 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
307 operator or function: list_append, operand type: {}",
308 dynamo_type_name(val)
309 ));
310 }
311 }
312 Ok(())
313}
314
315fn resolve_tracked_path(
317 path: &[PathElement],
318 tracker: &TrackedExpressionAttributes,
319) -> Vec<PathElement> {
320 path.iter()
321 .map(|elem| {
322 if let PathElement::Attribute(name) = elem {
323 if name.starts_with('#') {
324 if let Ok(resolved) = tracker.resolve_name(name) {
325 return PathElement::Attribute(resolved);
326 }
327 }
328 }
329 elem.clone()
330 })
331 .collect()
332}
333
334fn format_path_for_error(path: &[PathElement]) -> String {
336 let parts: Vec<String> = path
337 .iter()
338 .map(|elem| match elem {
339 PathElement::Attribute(name) => name.clone(),
340 PathElement::Index(i) => format!("[{i}]"),
341 })
342 .collect();
343 format!("[{}]", parts.join(", "))
344}
345
346fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
352 for i in 0..paths.len() {
353 for j in (i + 1)..paths.len() {
354 let a = &paths[i];
355 let b = &paths[j];
356 let min_len = a.len().min(b.len());
357
358 let mut common = 0;
360 for k in 0..min_len {
361 if a[k] == b[k] {
362 common += 1;
363 } else {
364 break;
365 }
366 }
367
368 if common == 0 {
369 continue;
370 }
371
372 if common == a.len() || common == b.len() {
374 let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
375 return Err(format!(
376 "Invalid UpdateExpression: Two document paths overlap with each other; \
377 must remove or rewrite one of these paths; \
378 path one: {}, path two: {}",
379 format_path_for_error(longer),
380 format_path_for_error(shorter)
381 ));
382 }
383
384 if common > 0 && common < min_len && a == b {
386 return Err(format!(
387 "Invalid UpdateExpression: Two document paths conflict with each other; \
388 must remove or rewrite one of these paths; \
389 path one: {}, path two: {}",
390 format_path_for_error(a),
391 format_path_for_error(b)
392 ));
393 }
394 }
395 }
396 Ok(())
397}
398
399fn track_path_refs(
400 path: &[PathElement],
401 tracker: &TrackedExpressionAttributes,
402) -> Result<(), String> {
403 for elem in path {
404 if let PathElement::Attribute(name) = elem {
405 if name.starts_with('#') {
406 tracker.resolve_name(name)?;
407 }
408 }
409 }
410 Ok(())
411}
412
413fn track_set_value_refs(
414 value: &SetValue,
415 tracker: &TrackedExpressionAttributes,
416) -> Result<(), String> {
417 match value {
418 SetValue::Operand(op) => track_set_operand_refs(op, tracker),
419 SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
420 track_set_operand_refs(left, tracker)?;
421 track_set_operand_refs(right, tracker)
422 }
423 }
424}
425
426fn track_set_operand_refs(
427 operand: &SetOperand,
428 tracker: &TrackedExpressionAttributes,
429) -> Result<(), String> {
430 match operand {
431 SetOperand::Path(path) => track_path_refs(path, tracker),
432 SetOperand::ValueRef(name) => {
433 tracker.resolve_value(name)?;
434 Ok(())
435 }
436 SetOperand::IfNotExists(path, default) => {
437 track_path_refs(path, tracker)?;
438 track_set_operand_refs(default, tracker)
439 }
440 SetOperand::ListAppend(a, b) => {
441 track_set_operand_refs(a, tracker)?;
442 track_set_operand_refs(b, tracker)
443 }
444 }
445}
446
447pub fn apply(
449 item: &mut HashMap<String, AttributeValue>,
450 expr: &UpdateExpr,
451 tracker: &TrackedExpressionAttributes,
452) -> Result<(), String> {
453 for action in &expr.set_actions {
455 let resolved_path = resolve_path_elements(&action.path, tracker)?;
456 let value = evaluate_set_value(&action.value, item, tracker)?;
457 set_path(item, &resolved_path, value)?;
458 }
459
460 for path in &expr.remove_actions {
462 let resolved_path = resolve_path_elements(path, tracker)?;
463 remove_path(item, &resolved_path)?;
464 }
465
466 for action in &expr.add_actions {
468 let resolved_path = resolve_path_elements(&action.path, tracker)?;
469 let add_val = tracker.resolve_value(&action.value_ref)?.clone();
470 apply_add(item, &resolved_path, &add_val).map_err(|_| {
471 "An operand in the update expression has an incorrect data type".to_string()
472 })?;
473 }
474
475 for action in &expr.delete_actions {
477 let resolved_path = resolve_path_elements(&action.path, tracker)?;
478 let del_val = tracker.resolve_value(&action.value_ref)?.clone();
479 apply_delete(item, &resolved_path, &del_val).map_err(|_| {
480 "An operand in the update expression has an incorrect data type".to_string()
481 })?;
482 }
483
484 Ok(())
485}
486
487fn evaluate_set_value(
492 value: &SetValue,
493 item: &HashMap<String, AttributeValue>,
494 tracker: &TrackedExpressionAttributes,
495) -> Result<AttributeValue, String> {
496 match value {
497 SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
498 SetValue::Plus(left, right) => {
499 let lv = evaluate_set_operand(left, item, tracker)?;
500 let rv = evaluate_set_operand(right, item, tracker)?;
501 match (&lv, &rv) {
502 (AttributeValue::N(a), AttributeValue::N(b)) => {
503 use bigdecimal::BigDecimal;
504 use std::str::FromStr;
505 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
506 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
507 let result = &da + &db;
508 Ok(AttributeValue::N(format_number(&result)))
509 }
510 _ => Err("Operands for + must be numbers".to_string()),
511 }
512 }
513 SetValue::Minus(left, right) => {
514 let lv = evaluate_set_operand(left, item, tracker)?;
515 let rv = evaluate_set_operand(right, item, tracker)?;
516 match (&lv, &rv) {
517 (AttributeValue::N(a), AttributeValue::N(b)) => {
518 use bigdecimal::BigDecimal;
519 use std::str::FromStr;
520 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
521 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
522 let result = &da - &db;
523 Ok(AttributeValue::N(format_number(&result)))
524 }
525 _ => Err("Operands for - must be numbers".to_string()),
526 }
527 }
528 }
529}
530
531fn evaluate_set_operand(
532 operand: &SetOperand,
533 item: &HashMap<String, AttributeValue>,
534 tracker: &TrackedExpressionAttributes,
535) -> Result<AttributeValue, String> {
536 match operand {
537 SetOperand::Path(path) => {
538 let resolved = resolve_path_elements(path, tracker)?;
539 resolve_path(item, &resolved).ok_or_else(|| {
540 "The provided expression refers to an attribute that does not exist in the item"
541 .to_string()
542 })
543 }
544 SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
545 SetOperand::IfNotExists(path, default) => {
546 let resolved = resolve_path_elements(path, tracker)?;
547 match resolve_path(item, &resolved) {
548 Some(existing) => Ok(existing),
549 None => evaluate_set_operand(default, item, tracker),
550 }
551 }
552 SetOperand::ListAppend(list1, list2) => {
553 let v1 = evaluate_set_operand(list1, item, tracker)?;
554 let v2 = evaluate_set_operand(list2, item, tracker)?;
555 match (v1, v2) {
556 (AttributeValue::L(mut a), AttributeValue::L(b)) => {
557 a.extend(b);
558 Ok(AttributeValue::L(a))
559 }
560 _ => Err("list_append requires two list operands".to_string()),
561 }
562 }
563 }
564}
565
566pub fn apply_add_public(
572 item: &mut HashMap<String, AttributeValue>,
573 path: &[PathElement],
574 add_val: &AttributeValue,
575) -> Result<(), String> {
576 apply_add(item, path, add_val)
577}
578
579fn apply_add(
580 item: &mut HashMap<String, AttributeValue>,
581 path: &[PathElement],
582 add_val: &AttributeValue,
583) -> Result<(), String> {
584 let existing = resolve_path(item, path);
585
586 match (existing, add_val) {
587 (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
589 use bigdecimal::BigDecimal;
590 use std::str::FromStr;
591 let de = BigDecimal::from_str(&existing_n)
592 .map_err(|_| format!("Invalid number: {existing_n}"))?;
593 let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
594 let result = &de + &da;
595 set_path(item, path, AttributeValue::N(format_number(&result)))
596 }
597 (None, AttributeValue::N(_)) => {
598 set_path(item, path, add_val.clone())
600 }
601
602 (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
604 for s in add_set {
605 if !existing_set.contains(s) {
606 existing_set.push(s.clone());
607 }
608 }
609 set_path(item, path, AttributeValue::SS(existing_set))
610 }
611 (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
612
613 (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
615 for n in add_set {
616 if !existing_set.contains(n) {
617 existing_set.push(n.clone());
618 }
619 }
620 set_path(item, path, AttributeValue::NS(existing_set))
621 }
622 (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
623
624 (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
626 for b in add_set {
627 if !existing_set.contains(b) {
628 existing_set.push(b.clone());
629 }
630 }
631 set_path(item, path, AttributeValue::BS(existing_set))
632 }
633 (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
634
635 (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
637 existing_list.extend(add_list.iter().cloned());
638 set_path(item, path, AttributeValue::L(existing_list))
639 }
640 (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
641
642 _ => Err("Type mismatch for attribute to update".to_string()),
643 }
644}
645
646pub fn apply_delete_public(
652 item: &mut HashMap<String, AttributeValue>,
653 path: &[PathElement],
654 del_val: &AttributeValue,
655) -> Result<(), String> {
656 apply_delete(item, path, del_val)
657}
658
659fn apply_delete(
660 item: &mut HashMap<String, AttributeValue>,
661 path: &[PathElement],
662 del_val: &AttributeValue,
663) -> Result<(), String> {
664 let existing = resolve_path(item, path);
665
666 match (existing, del_val) {
667 (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
668 let new_set: Vec<String> = existing_set
669 .into_iter()
670 .filter(|s| !del_set.contains(s))
671 .collect();
672 if new_set.is_empty() {
673 remove_path(item, path)
674 } else {
675 set_path(item, path, AttributeValue::SS(new_set))
676 }
677 }
678 (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
679 let new_set: Vec<String> = existing_set
680 .into_iter()
681 .filter(|n| !del_set.contains(n))
682 .collect();
683 if new_set.is_empty() {
684 remove_path(item, path)
685 } else {
686 set_path(item, path, AttributeValue::NS(new_set))
687 }
688 }
689 (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
690 let new_set: Vec<Vec<u8>> = existing_set
691 .into_iter()
692 .filter(|b| !del_set.contains(b))
693 .collect();
694 if new_set.is_empty() {
695 remove_path(item, path)
696 } else {
697 set_path(item, path, AttributeValue::BS(new_set))
698 }
699 }
700 (None, _) => Ok(()), _ => Err("Type mismatch for attribute to update".to_string()),
702 }
703}
704
705fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
710 actions.push(parse_set_action(stream)?);
711 while matches!(stream.peek(), Some(Token::Comma)) {
712 stream.next();
713 actions.push(parse_set_action(stream)?);
714 }
715 Ok(())
716}
717
718fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
719 let path = parse_raw_path(stream)?;
720 stream.expect(&Token::Eq)?;
721 let value = parse_set_value(stream)?;
722 Ok(SetAction { path, value })
723}
724
725fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
726 let left = parse_set_operand(stream)?;
727
728 match stream.peek() {
729 Some(Token::Plus) => {
730 stream.next();
731 let right = parse_set_operand(stream)?;
732 Ok(SetValue::Plus(left, right))
733 }
734 Some(Token::Minus) => {
735 stream.next();
736 let right = parse_set_operand(stream)?;
737 Ok(SetValue::Minus(left, right))
738 }
739 _ => Ok(SetValue::Operand(left)),
740 }
741}
742
743fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
744 if let Some(Token::Identifier(name)) = stream.peek() {
746 let func_name = name.to_lowercase();
747 let orig_name = name.clone();
748 match func_name.as_str() {
749 "if_not_exists" => {
750 stream.next();
751 stream.expect(&Token::LParen)?;
752
753 match stream.peek() {
755 Some(Token::ValueRef(_)) => {
756 return Err(
757 "Invalid UpdateExpression: Operator or function requires a document path; \
758 operator or function: if_not_exists".to_string()
759 );
760 }
761 Some(Token::Identifier(fname))
762 if fname.to_lowercase() == "if_not_exists"
763 || fname.to_lowercase() == "list_append" =>
764 {
765 return Err(
766 "Invalid UpdateExpression: Operator or function requires a document path; \
767 operator or function: if_not_exists".to_string()
768 );
769 }
770 _ => {}
771 }
772
773 let path = parse_raw_path(stream)?;
774
775 if !matches!(stream.peek(), Some(Token::Comma)) {
777 return Err(
778 "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
779 operator or function: if_not_exists, number of operands: 1".to_string()
780 );
781 }
782 stream.expect(&Token::Comma)?;
783 let default = parse_set_operand(stream)?;
784 stream.expect(&Token::RParen)?;
785 return Ok(SetOperand::IfNotExists(path, Box::new(default)));
786 }
787 "list_append" => {
788 stream.next();
789 stream.expect(&Token::LParen)?;
790 let list1 = parse_set_operand(stream)?;
791
792 if !matches!(stream.peek(), Some(Token::Comma)) {
794 return Err(
795 "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
796 operator or function: list_append, number of operands: 1".to_string()
797 );
798 }
799 stream.expect(&Token::Comma)?;
800 let list2 = parse_set_operand(stream)?;
801 stream.expect(&Token::RParen)?;
802 return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
803 }
804 _ => {
805 let saved_pos = stream.pos();
808 stream.next();
809 if matches!(stream.peek(), Some(Token::LParen)) {
810 return Err(format!(
811 "Invalid UpdateExpression: Invalid function name; function: {}",
812 orig_name
813 ));
814 }
815 stream.set_pos(saved_pos);
817 }
818 }
819 }
820
821 match stream.peek() {
822 Some(Token::ValueRef(_)) => {
823 if let Some(Token::ValueRef(name)) = stream.next().cloned() {
824 Ok(SetOperand::ValueRef(name))
825 } else {
826 unreachable!()
827 }
828 }
829 Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
830 let path = parse_raw_path(stream)?;
831 Ok(SetOperand::Path(path))
832 }
833 Some(t) => Err(format!("Expected operand in SET, got {t}")),
834 None => Err("Expected operand in SET, got end of expression".to_string()),
835 }
836}
837
838fn parse_remove_clause(
839 stream: &mut TokenStream,
840 actions: &mut Vec<Vec<PathElement>>,
841) -> Result<(), String> {
842 actions.push(parse_raw_path(stream)?);
843 while matches!(stream.peek(), Some(Token::Comma)) {
844 stream.next();
845 actions.push(parse_raw_path(stream)?);
846 }
847 Ok(())
848}
849
850fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
851 actions.push(parse_add_action(stream)?);
852 while matches!(stream.peek(), Some(Token::Comma)) {
853 stream.next();
854 actions.push(parse_add_action(stream)?);
855 }
856 Ok(())
857}
858
859fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
860 let path = parse_raw_path(stream)?;
861 match stream.next() {
862 Some(Token::ValueRef(name)) => Ok(AddAction {
863 path,
864 value_ref: name.clone(),
865 }),
866 Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
867 None => Err("Expected value reference in ADD, got end of expression".to_string()),
868 }
869}
870
871fn parse_delete_clause(
872 stream: &mut TokenStream,
873 actions: &mut Vec<DeleteAction>,
874) -> Result<(), String> {
875 actions.push(parse_delete_action(stream)?);
876 while matches!(stream.peek(), Some(Token::Comma)) {
877 stream.next();
878 actions.push(parse_delete_action(stream)?);
879 }
880 Ok(())
881}
882
883fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
884 let path = parse_raw_path(stream)?;
885 match stream.next() {
886 Some(Token::ValueRef(name)) => Ok(DeleteAction {
887 path,
888 value_ref: name.clone(),
889 }),
890 Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
891 None => Err("Expected value reference in DELETE, got end of expression".to_string()),
892 }
893}
894
895fn format_number(n: &bigdecimal::BigDecimal) -> String {
898 let normalized = n.normalized();
899 if normalized.as_bigint_and_exponent().1 < 0 {
903 normalized.with_scale(0).to_string()
904 } else {
905 normalized.to_string()
906 }
907}
908
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an item (or value map) from `(name, value)` pairs.
    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// Builds `ExpressionAttributeValues` from `(placeholder, value)` pairs.
    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
    }

    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // Give the attribute a value different from the default, so the second
        // application only passes if the existing value is kept.
        item.insert("hits".to_string(), AttributeValue::N("7".into()));
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("7".into()));
    }

    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::L(list) = &item["entries"] {
            assert_eq!(list.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_add_rejects_string_operand() {
        let expr = parse("ADD tally :s").unwrap();
        let av = vals(&[(":s", AttributeValue::S("x".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        let err = track_references(&expr, &tracker).unwrap_err();
        assert!(err.contains("operator: ADD, operand type: STRING"));
    }

    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    #[test]
    fn test_delete_last_members_removes_attribute() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("colors", AttributeValue::SS(vec!["red".into()])),
        ]);
        let av = vals(&[(":remove", AttributeValue::SS(vec!["red".into()]))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("colors"));
    }

    #[test]
    fn test_overlapping_paths_rejected() {
        let expr = parse("SET label = :v REMOVE label").unwrap();
        let av = vals(&[(":v", AttributeValue::S("x".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        let err = track_references(&expr, &tracker).unwrap_err();
        assert!(err.contains("Two document paths overlap"));
    }

    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);
    }

    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }

    #[test]
    fn test_duplicate_remove_clause_error() {
        let err = parse("REMOVE a REMOVE b").unwrap_err();
        assert!(err.contains("\"REMOVE\" section can only be used once"));
    }

    #[test]
    fn test_wrap_syntax_error() {
        assert_eq!(
            wrap_syntax_error("Invalid UpdateExpression: already wrapped".to_string()),
            "Invalid UpdateExpression: already wrapped"
        );
        assert_eq!(
            wrap_syntax_error("Attribute name is a reserved keyword; reserved keyword: x".to_string()),
            "Invalid UpdateExpression: Attribute name is a reserved keyword; reserved keyword: x"
        );
        assert_eq!(
            wrap_syntax_error("unexpected token".to_string()),
            "Invalid UpdateExpression: Syntax error; unexpected token"
        );
    }
}