1use crate::expressions::condition::parse_raw_path;
6use crate::expressions::tokenizer::{
7 Token, TokenStream, near_window_parser, near_window_tokenizer, tokenize,
8};
9use crate::expressions::{
10 PathElement, TrackedExpressionAttributes, remove_path, resolve_path, resolve_path_elements,
11 set_path,
12};
13use crate::types::AttributeValue;
14use std::collections::HashMap;
15
/// Parsed form of an `UpdateExpression`: one list of actions per clause.
///
/// Produced by [`parse`]; each list holds its clause's actions in source order.
#[derive(Debug)]
pub struct UpdateExpr {
    /// Assignments from the `SET` clause.
    pub set_actions: Vec<SetAction>,
    /// Unresolved document paths listed in the `REMOVE` clause.
    pub remove_actions: Vec<Vec<PathElement>>,
    /// `path :value` pairs from the `ADD` clause.
    pub add_actions: Vec<AddAction>,
    /// `path :value` pairs from the `DELETE` clause.
    pub delete_actions: Vec<DeleteAction>,
}
24
/// One `path = value` assignment inside a `SET` clause.
#[derive(Debug)]
pub struct SetAction {
    /// Target document path, still containing any `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Right-hand side of the assignment.
    pub value: SetValue,
}
31
/// Right-hand side of a `SET` assignment: a single operand, or a binary
/// `+` / `-` between two operands (no further chaining is parsed).
#[derive(Debug)]
pub enum SetValue {
    /// A lone operand.
    Operand(SetOperand),
    /// `left + right`; both sides must evaluate to numbers.
    Plus(SetOperand, SetOperand),
    /// `left - right`; both sides must evaluate to numbers.
    Minus(SetOperand, SetOperand),
}
42
/// A single operand usable on the right-hand side of `SET`.
#[derive(Debug)]
pub enum SetOperand {
    /// Value read from a document path in the current item.
    Path(Vec<PathElement>),
    /// Value placeholder, resolved through `TrackedExpressionAttributes::resolve_value`.
    ValueRef(String),
    /// `if_not_exists(path, default)`: the value at `path`, or `default` when absent.
    IfNotExists(Vec<PathElement>, Box<SetOperand>),
    /// `list_append(a, b)`: the elements of list `a` followed by those of list `b`.
    ListAppend(Box<SetOperand>, Box<SetOperand>),
}
51
/// One `path :value` pair inside an `ADD` clause.
#[derive(Debug)]
pub struct AddAction {
    /// Target document path, still containing any `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Name of the value placeholder whose value is added.
    pub value_ref: String,
}
58
/// One `path :value` pair inside a `DELETE` clause.
#[derive(Debug)]
pub struct DeleteAction {
    /// Target document path, still containing any `#name` placeholders.
    pub path: Vec<PathElement>,
    /// Name of the value placeholder holding the set elements to remove.
    pub value_ref: String,
}
65
66pub fn parse(expr: &str) -> Result<UpdateExpr, String> {
68 let tokens = match tokenize(expr) {
69 Ok(t) => t,
70 Err(err) => {
71 let bad = &expr[err.position..err.position + err.bad_len];
75 let near = near_window_tokenizer(expr, err.position);
76 return Err(format!(
77 r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
78 ));
79 }
80 };
81 let mut stream = TokenStream::new(tokens);
82
83 let mut set_actions = Vec::new();
84 let mut remove_actions = Vec::new();
85 let mut add_actions = Vec::new();
86 let mut delete_actions = Vec::new();
87
88 let mut seen_set = false;
89 let mut seen_remove = false;
90 let mut seen_add = false;
91 let mut seen_delete = false;
92
93 while !stream.at_end() {
94 match stream.peek() {
95 Some(Token::Set) => {
96 if seen_set {
97 return Err("Invalid UpdateExpression: The \"SET\" section can only be used once in an update expression;".to_string());
98 }
99 seen_set = true;
100 stream.next();
101 parse_set_clause(&mut stream, &mut set_actions).map_err(wrap_syntax_error)?;
102 }
103 Some(Token::Remove) => {
104 if seen_remove {
105 return Err("Invalid UpdateExpression: The \"REMOVE\" section can only be used once in an update expression;".to_string());
106 }
107 seen_remove = true;
108 stream.next();
109 parse_remove_clause(&mut stream, &mut remove_actions).map_err(wrap_syntax_error)?;
110 }
111 Some(Token::Add) => {
112 if seen_add {
113 return Err("Invalid UpdateExpression: The \"ADD\" section can only be used once in an update expression;".to_string());
114 }
115 seen_add = true;
116 stream.next();
117 parse_add_clause(&mut stream, &mut add_actions).map_err(wrap_syntax_error)?;
118 }
119 Some(Token::Delete) => {
120 if seen_delete {
121 return Err("Invalid UpdateExpression: The \"DELETE\" section can only be used once in an update expression;".to_string());
122 }
123 seen_delete = true;
124 stream.next();
125 parse_delete_clause(&mut stream, &mut delete_actions).map_err(wrap_syntax_error)?;
126 }
127 Some(_) => {
128 let offending_span = stream
132 .peek_span()
133 .expect("peek_span must yield when peek did");
134 let bad = &expr[offending_span.start..offending_span.end()];
135 stream.next();
136 let next_span = stream.peek_span();
137 let near = near_window_parser(expr, offending_span, next_span);
138 return Err(format!(
139 r#"Invalid UpdateExpression: Syntax error; token: "{bad}", near: "{near}""#
140 ));
141 }
142 None => break,
143 }
144 }
145
146 Ok(UpdateExpr {
147 set_actions,
148 remove_actions,
149 add_actions,
150 delete_actions,
151 })
152}
153
/// Normalises a clause-level parser error into an `Invalid UpdateExpression:` message.
///
/// Messages that already carry the prefix pass through untouched. Reserved-keyword
/// errors get only the prefix. Everything else is reported as a syntax error.
fn wrap_syntax_error(err: String) -> String {
    let already_wrapped = err.starts_with("Invalid UpdateExpression:");
    if already_wrapped {
        return err;
    }

    let prefix = if err.starts_with("Attribute name is a reserved keyword") {
        "Invalid UpdateExpression: "
    } else {
        "Invalid UpdateExpression: Syntax error; "
    };
    format!("{prefix}{err}")
}
165
166pub fn track_references(
171 expr: &UpdateExpr,
172 tracker: &TrackedExpressionAttributes,
173) -> Result<(), String> {
174 let mut all_target_paths: Vec<Vec<PathElement>> = Vec::new();
176
177 for action in &expr.set_actions {
178 track_path_refs(&action.path, tracker)?;
179 track_set_value_refs(&action.value, tracker)?;
180 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
181 }
182 for path in &expr.remove_actions {
183 track_path_refs(path, tracker)?;
184 all_target_paths.push(resolve_tracked_path(path, tracker));
185 }
186 for action in &expr.add_actions {
187 track_path_refs(&action.path, tracker)?;
188 let val = tracker.resolve_value(&action.value_ref)?;
189 validate_add_type(val)?;
191 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
192 }
193 for action in &expr.delete_actions {
194 track_path_refs(&action.path, tracker)?;
195 let val = tracker.resolve_value(&action.value_ref)?;
196 validate_delete_type(val)?;
198 all_target_paths.push(resolve_tracked_path(&action.path, tracker));
199 }
200
201 for action in &expr.set_actions {
203 validate_set_value_types(&action.value, tracker)?;
204 }
205
206 check_path_overlaps(&all_target_paths)?;
208
209 Ok(())
210}
211
212fn validate_add_type(val: &crate::types::AttributeValue) -> Result<(), String> {
214 use crate::types::AttributeValue;
215 match val {
216 AttributeValue::N(_)
217 | AttributeValue::SS(_)
218 | AttributeValue::NS(_)
219 | AttributeValue::BS(_) => Ok(()),
220 _ => Err(format!(
221 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
222 operator: ADD, operand type: {}",
223 dynamo_type_name(val)
224 )),
225 }
226}
227
228fn validate_delete_type(val: &crate::types::AttributeValue) -> Result<(), String> {
230 use crate::types::AttributeValue;
231 match val {
232 AttributeValue::SS(_) | AttributeValue::NS(_) | AttributeValue::BS(_) => Ok(()),
233 _ => Err(format!(
234 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
235 operator: DELETE, operand type: {}",
236 dynamo_type_name(val)
237 )),
238 }
239}
240
241fn dynamo_type_name(val: &crate::types::AttributeValue) -> &'static str {
243 use crate::types::AttributeValue;
244 match val {
245 AttributeValue::S(_) => "STRING",
246 AttributeValue::N(_) => "NUMBER",
247 AttributeValue::B(_) => "BINARY",
248 AttributeValue::BOOL(_) => "BOOLEAN",
249 AttributeValue::NULL(_) => "NULL",
250 AttributeValue::SS(_) => "SS",
251 AttributeValue::NS(_) => "NS",
252 AttributeValue::BS(_) => "BS",
253 AttributeValue::L(_) => "LIST",
254 AttributeValue::M(_) => "MAP",
255 }
256}
257
258fn validate_set_value_types(
260 value: &SetValue,
261 tracker: &TrackedExpressionAttributes,
262) -> Result<(), String> {
263 match value {
264 SetValue::Operand(op) => validate_set_operand_types(op, tracker),
265 SetValue::Plus(left, right) => {
266 validate_arithmetic_operand(left, "+", tracker)?;
267 validate_arithmetic_operand(right, "+", tracker)
268 }
269 SetValue::Minus(left, right) => {
270 validate_arithmetic_operand(left, "-", tracker)?;
271 validate_arithmetic_operand(right, "-", tracker)
272 }
273 }
274}
275
276fn validate_arithmetic_operand(
278 operand: &SetOperand,
279 op: &str,
280 tracker: &TrackedExpressionAttributes,
281) -> Result<(), String> {
282 use crate::types::AttributeValue;
283 match operand {
284 SetOperand::ValueRef(name) => {
285 let val = tracker.resolve_value(name)?;
286 if !matches!(val, AttributeValue::N(_)) {
287 return Err(format!(
288 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
289 operator or function: {op}, operand type: {}",
290 dynamo_type_name(val)
291 ));
292 }
293 Ok(())
294 }
295 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
296 SetOperand::ListAppend(a, b) => {
297 validate_list_append_operand(a, tracker)?;
298 validate_list_append_operand(b, tracker)
299 }
300 SetOperand::Path(_) => Ok(()), }
302}
303
304fn validate_set_operand_types(
306 operand: &SetOperand,
307 tracker: &TrackedExpressionAttributes,
308) -> Result<(), String> {
309 match operand {
310 SetOperand::ListAppend(a, b) => {
311 validate_list_append_operand(a, tracker)?;
312 validate_list_append_operand(b, tracker)
313 }
314 SetOperand::IfNotExists(_, default) => validate_set_operand_types(default, tracker),
315 _ => Ok(()),
316 }
317}
318
319fn validate_list_append_operand(
321 operand: &SetOperand,
322 tracker: &TrackedExpressionAttributes,
323) -> Result<(), String> {
324 use crate::types::AttributeValue;
325 if let SetOperand::ValueRef(name) = operand {
326 let val = tracker.resolve_value(name)?;
327 if !matches!(val, AttributeValue::L(_)) {
328 return Err(format!(
329 "Invalid UpdateExpression: Incorrect operand type for operator or function; \
330 operator or function: list_append, operand type: {}",
331 dynamo_type_name(val)
332 ));
333 }
334 }
335 Ok(())
336}
337
338fn resolve_tracked_path(
340 path: &[PathElement],
341 tracker: &TrackedExpressionAttributes,
342) -> Vec<PathElement> {
343 path.iter()
344 .map(|elem| {
345 if let PathElement::Attribute(name) = elem {
346 if name.starts_with('#') {
347 if let Ok(resolved) = tracker.resolve_name(name) {
348 return PathElement::Attribute(resolved);
349 }
350 }
351 }
352 elem.clone()
353 })
354 .collect()
355}
356
357fn format_path_for_error(path: &[PathElement]) -> String {
359 let parts: Vec<String> = path
360 .iter()
361 .map(|elem| match elem {
362 PathElement::Attribute(name) => name.clone(),
363 PathElement::Index(i) => format!("[{i}]"),
364 })
365 .collect();
366 format!("[{}]", parts.join(", "))
367}
368
369fn check_path_overlaps(paths: &[Vec<PathElement>]) -> Result<(), String> {
375 for i in 0..paths.len() {
376 for j in (i + 1)..paths.len() {
377 let a = &paths[i];
378 let b = &paths[j];
379 let min_len = a.len().min(b.len());
380
381 let mut common = 0;
383 for k in 0..min_len {
384 if a[k] == b[k] {
385 common += 1;
386 } else {
387 break;
388 }
389 }
390
391 if common == 0 {
392 continue;
393 }
394
395 if common == a.len() || common == b.len() {
397 let (shorter, longer) = if a.len() <= b.len() { (a, b) } else { (b, a) };
398 return Err(format!(
399 "Invalid UpdateExpression: Two document paths overlap with each other; \
400 must remove or rewrite one of these paths; \
401 path one: {}, path two: {}",
402 format_path_for_error(longer),
403 format_path_for_error(shorter)
404 ));
405 }
406
407 if common > 0 && common < min_len && a == b {
409 return Err(format!(
410 "Invalid UpdateExpression: Two document paths conflict with each other; \
411 must remove or rewrite one of these paths; \
412 path one: {}, path two: {}",
413 format_path_for_error(a),
414 format_path_for_error(b)
415 ));
416 }
417 }
418 }
419 Ok(())
420}
421
422fn track_path_refs(
423 path: &[PathElement],
424 tracker: &TrackedExpressionAttributes,
425) -> Result<(), String> {
426 for elem in path {
427 if let PathElement::Attribute(name) = elem {
428 if name.starts_with('#') {
429 tracker.resolve_name(name)?;
430 }
431 }
432 }
433 Ok(())
434}
435
436fn track_set_value_refs(
437 value: &SetValue,
438 tracker: &TrackedExpressionAttributes,
439) -> Result<(), String> {
440 match value {
441 SetValue::Operand(op) => track_set_operand_refs(op, tracker),
442 SetValue::Plus(left, right) | SetValue::Minus(left, right) => {
443 track_set_operand_refs(left, tracker)?;
444 track_set_operand_refs(right, tracker)
445 }
446 }
447}
448
449fn track_set_operand_refs(
450 operand: &SetOperand,
451 tracker: &TrackedExpressionAttributes,
452) -> Result<(), String> {
453 match operand {
454 SetOperand::Path(path) => track_path_refs(path, tracker),
455 SetOperand::ValueRef(name) => {
456 tracker.resolve_value(name)?;
457 Ok(())
458 }
459 SetOperand::IfNotExists(path, default) => {
460 track_path_refs(path, tracker)?;
461 track_set_operand_refs(default, tracker)
462 }
463 SetOperand::ListAppend(a, b) => {
464 track_set_operand_refs(a, tracker)?;
465 track_set_operand_refs(b, tracker)
466 }
467 }
468}
469
470pub fn apply(
472 item: &mut HashMap<String, AttributeValue>,
473 expr: &UpdateExpr,
474 tracker: &TrackedExpressionAttributes,
475) -> Result<(), String> {
476 for action in &expr.set_actions {
478 let resolved_path = resolve_path_elements(&action.path, tracker)?;
479 let value = evaluate_set_value(&action.value, item, tracker)?;
480 set_path(item, &resolved_path, value)?;
481 }
482
483 for path in &expr.remove_actions {
485 let resolved_path = resolve_path_elements(path, tracker)?;
486 remove_path(item, &resolved_path)?;
487 }
488
489 for action in &expr.add_actions {
491 let resolved_path = resolve_path_elements(&action.path, tracker)?;
492 let add_val = tracker.resolve_value(&action.value_ref)?.clone();
493 apply_add(item, &resolved_path, &add_val).map_err(|_| {
494 "An operand in the update expression has an incorrect data type".to_string()
495 })?;
496 }
497
498 for action in &expr.delete_actions {
500 let resolved_path = resolve_path_elements(&action.path, tracker)?;
501 let del_val = tracker.resolve_value(&action.value_ref)?.clone();
502 apply_delete(item, &resolved_path, &del_val).map_err(|_| {
503 "An operand in the update expression has an incorrect data type".to_string()
504 })?;
505 }
506
507 Ok(())
508}
509
510fn evaluate_set_value(
515 value: &SetValue,
516 item: &HashMap<String, AttributeValue>,
517 tracker: &TrackedExpressionAttributes,
518) -> Result<AttributeValue, String> {
519 match value {
520 SetValue::Operand(op) => evaluate_set_operand(op, item, tracker),
521 SetValue::Plus(left, right) => {
522 let lv = evaluate_set_operand(left, item, tracker)?;
523 let rv = evaluate_set_operand(right, item, tracker)?;
524 match (&lv, &rv) {
525 (AttributeValue::N(a), AttributeValue::N(b)) => {
526 use bigdecimal::BigDecimal;
527 use std::str::FromStr;
528 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
529 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
530 let result = &da + &db;
531 Ok(AttributeValue::N(format_number(&result)))
532 }
533 _ => Err("Operands for + must be numbers".to_string()),
534 }
535 }
536 SetValue::Minus(left, right) => {
537 let lv = evaluate_set_operand(left, item, tracker)?;
538 let rv = evaluate_set_operand(right, item, tracker)?;
539 match (&lv, &rv) {
540 (AttributeValue::N(a), AttributeValue::N(b)) => {
541 use bigdecimal::BigDecimal;
542 use std::str::FromStr;
543 let da = BigDecimal::from_str(a).map_err(|_| format!("Invalid number: {a}"))?;
544 let db = BigDecimal::from_str(b).map_err(|_| format!("Invalid number: {b}"))?;
545 let result = &da - &db;
546 Ok(AttributeValue::N(format_number(&result)))
547 }
548 _ => Err("Operands for - must be numbers".to_string()),
549 }
550 }
551 }
552}
553
554fn evaluate_set_operand(
555 operand: &SetOperand,
556 item: &HashMap<String, AttributeValue>,
557 tracker: &TrackedExpressionAttributes,
558) -> Result<AttributeValue, String> {
559 match operand {
560 SetOperand::Path(path) => {
561 let resolved = resolve_path_elements(path, tracker)?;
562 resolve_path(item, &resolved).ok_or_else(|| {
563 "The provided expression refers to an attribute that does not exist in the item"
564 .to_string()
565 })
566 }
567 SetOperand::ValueRef(name) => Ok(tracker.resolve_value(name)?.clone()),
568 SetOperand::IfNotExists(path, default) => {
569 let resolved = resolve_path_elements(path, tracker)?;
570 match resolve_path(item, &resolved) {
571 Some(existing) => Ok(existing),
572 None => evaluate_set_operand(default, item, tracker),
573 }
574 }
575 SetOperand::ListAppend(list1, list2) => {
576 let v1 = evaluate_set_operand(list1, item, tracker)?;
577 let v2 = evaluate_set_operand(list2, item, tracker)?;
578 match (v1, v2) {
579 (AttributeValue::L(mut a), AttributeValue::L(b)) => {
580 a.extend(b);
581 Ok(AttributeValue::L(a))
582 }
583 _ => Err("list_append requires two list operands".to_string()),
584 }
585 }
586 }
587}
588
/// Public entry point to the `ADD` merge logic that [`apply`] uses.
///
/// `path` is used as-is; no `#name` placeholder resolution happens here.
/// Behaves exactly like the private `apply_add`: numbers are summed, sets are
/// unioned, lists are concatenated, and an absent attribute is created from
/// `add_val`.
///
/// # Errors
/// `"Type mismatch for attribute to update"` for incompatible types, an
/// `Invalid number` message for unparsable numbers, or a propagated path error.
pub fn apply_add_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    add_val: &AttributeValue,
) -> Result<(), String> {
    apply_add(item, path, add_val)
}
601
602fn apply_add(
603 item: &mut HashMap<String, AttributeValue>,
604 path: &[PathElement],
605 add_val: &AttributeValue,
606) -> Result<(), String> {
607 let existing = resolve_path(item, path);
608
609 match (existing, add_val) {
610 (Some(AttributeValue::N(existing_n)), AttributeValue::N(add_n)) => {
612 use bigdecimal::BigDecimal;
613 use std::str::FromStr;
614 let de = BigDecimal::from_str(&existing_n)
615 .map_err(|_| format!("Invalid number: {existing_n}"))?;
616 let da = BigDecimal::from_str(add_n).map_err(|_| format!("Invalid number: {add_n}"))?;
617 let result = &de + &da;
618 set_path(item, path, AttributeValue::N(format_number(&result)))
619 }
620 (None, AttributeValue::N(_)) => {
621 set_path(item, path, add_val.clone())
623 }
624
625 (Some(AttributeValue::SS(mut existing_set)), AttributeValue::SS(add_set)) => {
627 for s in add_set {
628 if !existing_set.contains(s) {
629 existing_set.push(s.clone());
630 }
631 }
632 set_path(item, path, AttributeValue::SS(existing_set))
633 }
634 (None, AttributeValue::SS(_)) => set_path(item, path, add_val.clone()),
635
636 (Some(AttributeValue::NS(mut existing_set)), AttributeValue::NS(add_set)) => {
638 for n in add_set {
639 if !existing_set.contains(n) {
640 existing_set.push(n.clone());
641 }
642 }
643 set_path(item, path, AttributeValue::NS(existing_set))
644 }
645 (None, AttributeValue::NS(_)) => set_path(item, path, add_val.clone()),
646
647 (Some(AttributeValue::BS(mut existing_set)), AttributeValue::BS(add_set)) => {
649 for b in add_set {
650 if !existing_set.contains(b) {
651 existing_set.push(b.clone());
652 }
653 }
654 set_path(item, path, AttributeValue::BS(existing_set))
655 }
656 (None, AttributeValue::BS(_)) => set_path(item, path, add_val.clone()),
657
658 (Some(AttributeValue::L(mut existing_list)), AttributeValue::L(add_list)) => {
660 existing_list.extend(add_list.iter().cloned());
661 set_path(item, path, AttributeValue::L(existing_list))
662 }
663 (None, AttributeValue::L(_)) => set_path(item, path, add_val.clone()),
664
665 _ => Err("Type mismatch for attribute to update".to_string()),
666 }
667}
668
/// Public entry point to the `DELETE` set-difference logic that [`apply`] uses.
///
/// `path` is used as-is; no `#name` placeholder resolution happens here.
/// Behaves exactly like the private `apply_delete`: elements of `del_val` are
/// removed from the same-kind set at `path`. The attribute is removed if that
/// empties the set. A missing attribute is a no-op.
///
/// # Errors
/// `"Type mismatch for attribute to update"` when the stored value and
/// `del_val` are not sets of the same kind, or a propagated path error.
pub fn apply_delete_public(
    item: &mut HashMap<String, AttributeValue>,
    path: &[PathElement],
    del_val: &AttributeValue,
) -> Result<(), String> {
    apply_delete(item, path, del_val)
}
681
682fn apply_delete(
683 item: &mut HashMap<String, AttributeValue>,
684 path: &[PathElement],
685 del_val: &AttributeValue,
686) -> Result<(), String> {
687 let existing = resolve_path(item, path);
688
689 match (existing, del_val) {
690 (Some(AttributeValue::SS(existing_set)), AttributeValue::SS(del_set)) => {
691 let new_set: Vec<String> = existing_set
692 .into_iter()
693 .filter(|s| !del_set.contains(s))
694 .collect();
695 if new_set.is_empty() {
696 remove_path(item, path)
697 } else {
698 set_path(item, path, AttributeValue::SS(new_set))
699 }
700 }
701 (Some(AttributeValue::NS(existing_set)), AttributeValue::NS(del_set)) => {
702 let new_set: Vec<String> = existing_set
703 .into_iter()
704 .filter(|n| !del_set.contains(n))
705 .collect();
706 if new_set.is_empty() {
707 remove_path(item, path)
708 } else {
709 set_path(item, path, AttributeValue::NS(new_set))
710 }
711 }
712 (Some(AttributeValue::BS(existing_set)), AttributeValue::BS(del_set)) => {
713 let new_set: Vec<Vec<u8>> = existing_set
714 .into_iter()
715 .filter(|b| !del_set.contains(b))
716 .collect();
717 if new_set.is_empty() {
718 remove_path(item, path)
719 } else {
720 set_path(item, path, AttributeValue::BS(new_set))
721 }
722 }
723 (None, _) => Ok(()), _ => Err("Type mismatch for attribute to update".to_string()),
725 }
726}
727
728fn parse_set_clause(stream: &mut TokenStream, actions: &mut Vec<SetAction>) -> Result<(), String> {
733 actions.push(parse_set_action(stream)?);
734 while matches!(stream.peek(), Some(Token::Comma)) {
735 stream.next();
736 actions.push(parse_set_action(stream)?);
737 }
738 Ok(())
739}
740
741fn parse_set_action(stream: &mut TokenStream) -> Result<SetAction, String> {
742 let path = parse_raw_path(stream)?;
743 stream.expect(&Token::Eq)?;
744 let value = parse_set_value(stream)?;
745 Ok(SetAction { path, value })
746}
747
748fn parse_set_value(stream: &mut TokenStream) -> Result<SetValue, String> {
749 let left = parse_set_operand(stream)?;
750
751 match stream.peek() {
752 Some(Token::Plus) => {
753 stream.next();
754 let right = parse_set_operand(stream)?;
755 Ok(SetValue::Plus(left, right))
756 }
757 Some(Token::Minus) => {
758 stream.next();
759 let right = parse_set_operand(stream)?;
760 Ok(SetValue::Minus(left, right))
761 }
762 _ => Ok(SetValue::Operand(left)),
763 }
764}
765
/// Parses a single operand on the right-hand side of a `SET` action.
///
/// Recognised forms:
/// * `if_not_exists(path, operand)`: the first argument must be a document
///   path. A value placeholder or a nested `if_not_exists` / `list_append`
///   call there is rejected, and so is a call without a second argument.
/// * `list_append(operand, operand)`: requires a second argument.
/// * a value placeholder token.
/// * a document path starting with an identifier or a `#name` reference.
///
/// Function names are matched case-insensitively. Any other identifier
/// immediately followed by `(` is reported as an invalid function name.
///
/// # Errors
/// A message for malformed calls, unknown functions, or a missing operand.
fn parse_set_operand(stream: &mut TokenStream) -> Result<SetOperand, String> {
    if let Some(Token::Identifier(name)) = stream.peek() {
        // Lower-cased copy for matching; the original spelling is kept for the
        // invalid-function error message.
        let func_name = name.to_lowercase();
        let orig_name = name.clone();
        match func_name.as_str() {
            "if_not_exists" => {
                stream.next();
                stream.expect(&Token::LParen)?;

                // Reject non-path first arguments before attempting a path parse.
                match stream.peek() {
                    Some(Token::ValueRef(_)) => {
                        return Err(
                            "Invalid UpdateExpression: Operator or function requires a document path; \
                             operator or function: if_not_exists".to_string()
                        );
                    }
                    Some(Token::Identifier(fname))
                        if fname.to_lowercase() == "if_not_exists"
                            || fname.to_lowercase() == "list_append" =>
                    {
                        return Err(
                            "Invalid UpdateExpression: Operator or function requires a document path; \
                             operator or function: if_not_exists".to_string()
                        );
                    }
                    _ => {}
                }

                let path = parse_raw_path(stream)?;

                // Anything other than `,` after the path is treated as a
                // one-argument call.
                if !matches!(stream.peek(), Some(Token::Comma)) {
                    return Err(
                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
                         operator or function: if_not_exists, number of operands: 1".to_string()
                    );
                }
                stream.expect(&Token::Comma)?;
                let default = parse_set_operand(stream)?;
                stream.expect(&Token::RParen)?;
                return Ok(SetOperand::IfNotExists(path, Box::new(default)));
            }
            "list_append" => {
                stream.next();
                stream.expect(&Token::LParen)?;
                let list1 = parse_set_operand(stream)?;

                // Anything other than `,` after the first argument is treated
                // as a one-argument call.
                if !matches!(stream.peek(), Some(Token::Comma)) {
                    return Err(
                        "Invalid UpdateExpression: Incorrect number of operands for operator or function; \
                         operator or function: list_append, number of operands: 1".to_string()
                    );
                }
                stream.expect(&Token::Comma)?;
                let list2 = parse_set_operand(stream)?;
                stream.expect(&Token::RParen)?;
                return Ok(SetOperand::ListAppend(Box::new(list1), Box::new(list2)));
            }
            _ => {
                // Look one token ahead: `name(` is an unknown function call.
                // Otherwise rewind so the identifier is parsed below as a path.
                let saved_pos = stream.pos();
                stream.next();
                if matches!(stream.peek(), Some(Token::LParen)) {
                    return Err(format!(
                        "Invalid UpdateExpression: Invalid function name; function: {}",
                        orig_name
                    ));
                }
                stream.set_pos(saved_pos);
            }
        }
    }

    match stream.peek() {
        Some(Token::ValueRef(_)) => {
            if let Some(Token::ValueRef(name)) = stream.next().cloned() {
                Ok(SetOperand::ValueRef(name))
            } else {
                // `peek` just returned a ValueRef, so `next` yields the same token.
                unreachable!()
            }
        }
        Some(Token::Identifier(_)) | Some(Token::NameRef(_)) => {
            let path = parse_raw_path(stream)?;
            Ok(SetOperand::Path(path))
        }
        Some(t) => Err(format!("Expected operand in SET, got {t}")),
        None => Err("Expected operand in SET, got end of expression".to_string()),
    }
}
860
861fn parse_remove_clause(
862 stream: &mut TokenStream,
863 actions: &mut Vec<Vec<PathElement>>,
864) -> Result<(), String> {
865 actions.push(parse_raw_path(stream)?);
866 while matches!(stream.peek(), Some(Token::Comma)) {
867 stream.next();
868 actions.push(parse_raw_path(stream)?);
869 }
870 Ok(())
871}
872
873fn parse_add_clause(stream: &mut TokenStream, actions: &mut Vec<AddAction>) -> Result<(), String> {
874 actions.push(parse_add_action(stream)?);
875 while matches!(stream.peek(), Some(Token::Comma)) {
876 stream.next();
877 actions.push(parse_add_action(stream)?);
878 }
879 Ok(())
880}
881
882fn parse_add_action(stream: &mut TokenStream) -> Result<AddAction, String> {
883 let path = parse_raw_path(stream)?;
884 match stream.next() {
885 Some(Token::ValueRef(name)) => Ok(AddAction {
886 path,
887 value_ref: name.clone(),
888 }),
889 Some(t) => Err(format!("Expected value reference in ADD, got {t}")),
890 None => Err("Expected value reference in ADD, got end of expression".to_string()),
891 }
892}
893
894fn parse_delete_clause(
895 stream: &mut TokenStream,
896 actions: &mut Vec<DeleteAction>,
897) -> Result<(), String> {
898 actions.push(parse_delete_action(stream)?);
899 while matches!(stream.peek(), Some(Token::Comma)) {
900 stream.next();
901 actions.push(parse_delete_action(stream)?);
902 }
903 Ok(())
904}
905
906fn parse_delete_action(stream: &mut TokenStream) -> Result<DeleteAction, String> {
907 let path = parse_raw_path(stream)?;
908 match stream.next() {
909 Some(Token::ValueRef(name)) => Ok(DeleteAction {
910 path,
911 value_ref: name.clone(),
912 }),
913 Some(t) => Err(format!("Expected value reference in DELETE, got {t}")),
914 None => Err("Expected value reference in DELETE, got end of expression".to_string()),
915 }
916}
917
918fn format_number(n: &bigdecimal::BigDecimal) -> String {
921 let normalized = n.normalized();
922 if normalized.as_bigint_and_exponent().1 < 0 {
926 normalized.with_scale(0).to_string()
927 } else {
928 normalized.to_string()
929 }
930}
931
/// Unit tests for parsing and applying update expressions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an item map from `(attribute name, value)` pairs.
    fn make_item(pairs: &[(&str, AttributeValue)]) -> HashMap<String, AttributeValue> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// Builds a placeholder-value map wrapped in `Some`, as `make_tracker` expects.
    fn vals(pairs: &[(&str, AttributeValue)]) -> Option<HashMap<String, AttributeValue>> {
        Some(make_item(pairs))
    }

    /// Wraps optional name/value placeholder maps in a `TrackedExpressionAttributes`.
    fn make_tracker<'a>(
        names: &'a Option<HashMap<String, String>>,
        values: &'a Option<HashMap<String, AttributeValue>>,
    ) -> TrackedExpressionAttributes<'a> {
        TrackedExpressionAttributes::new(names, values)
    }

    /// A single SET assignment parses into one SET action and nothing else.
    #[test]
    fn test_set_simple() {
        let expr = parse("SET label = :val").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert!(expr.remove_actions.is_empty());
    }

    /// Comma-separated assignments become separate SET actions.
    #[test]
    fn test_set_multiple() {
        let expr = parse("SET a = :v1, b = :v2").unwrap();
        assert_eq!(expr.set_actions.len(), 2);
    }

    /// `path + :value` adds numerically.
    #[test]
    fn test_set_arithmetic_plus() {
        let expr = parse("SET tally = tally + :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    /// `path - :value` subtracts numerically.
    #[test]
    fn test_set_arithmetic_minus() {
        let expr = parse("SET price = price - :discount").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("price", AttributeValue::N("100".into())),
        ]);
        let av = vals(&[(":discount", AttributeValue::N("25".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["price"], AttributeValue::N("75".into()));
    }

    /// `if_not_exists` falls back to the default when the attribute is absent.
    #[test]
    fn test_set_if_not_exists() {
        let expr = parse("SET hits = if_not_exists(hits, :zero)").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":zero", AttributeValue::N("0".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));

        // NOTE(review): the stored value and the default are both "0", so this
        // second apply cannot distinguish "kept existing" from "reset to default".
        let tracker2 = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker2).unwrap();
        assert_eq!(item["hits"], AttributeValue::N("0".into()));
    }

    /// `list_append` concatenates the stored list with the placeholder list.
    #[test]
    fn test_set_list_append() {
        let expr = parse("SET entries = list_append(entries, :new)").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "entries",
                AttributeValue::L(vec![AttributeValue::S("a".into())]),
            ),
        ]);
        let av = vals(&[(
            ":new",
            AttributeValue::L(vec![AttributeValue::S("b".into())]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::L(list) = &item["entries"] {
            assert_eq!(list.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    /// REMOVE deletes exactly the listed attributes.
    #[test]
    fn test_remove() {
        let expr = parse("REMOVE attr1, attr2").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("attr1", AttributeValue::S("a".into())),
            ("attr2", AttributeValue::S("b".into())),
            ("attr3", AttributeValue::S("c".into())),
        ]);
        let no_names = None;
        let no_values = None;
        let tracker = make_tracker(&no_names, &no_values);
        apply(&mut item, &expr, &tracker).unwrap();
        assert!(!item.contains_key("attr1"));
        assert!(!item.contains_key("attr2"));
        assert!(item.contains_key("attr3"));
    }

    /// ADD on an existing number sums the values.
    #[test]
    fn test_add_number() {
        let expr = parse("ADD tally :inc").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            ("tally", AttributeValue::N("10".into())),
        ]);
        let av = vals(&[(":inc", AttributeValue::N("5".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("15".into()));
    }

    /// ADD on a missing attribute creates it with the operand value.
    #[test]
    fn test_add_number_create() {
        let expr = parse("ADD tally :val").unwrap();
        let mut item = make_item(&[("pk", AttributeValue::S("k".into()))]);
        let av = vals(&[(":val", AttributeValue::N("1".into()))]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        assert_eq!(item["tally"], AttributeValue::N("1".into()));
    }

    /// ADD on a string set unions without duplicating shared members.
    #[test]
    fn test_add_string_set() {
        let expr = parse("ADD colors :new_colors").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into()]),
            ),
        ]);
        let av = vals(&[(
            ":new_colors",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            // "blue" was already present, so only "green" is added.
            assert_eq!(set.len(), 3);
            assert!(set.contains(&"green".to_string()));
        } else {
            panic!("Expected SS");
        }
    }

    /// DELETE removes the listed members from a string set.
    #[test]
    fn test_delete_string_set() {
        let expr = parse("DELETE colors :remove").unwrap();
        let mut item = make_item(&[
            ("pk", AttributeValue::S("k".into())),
            (
                "colors",
                AttributeValue::SS(vec!["red".into(), "blue".into(), "green".into()]),
            ),
        ]);
        let av = vals(&[(
            ":remove",
            AttributeValue::SS(vec!["blue".into(), "green".into()]),
        )]);
        let no_names = None;
        let tracker = make_tracker(&no_names, &av);
        apply(&mut item, &expr, &tracker).unwrap();
        if let AttributeValue::SS(set) = &item["colors"] {
            assert_eq!(set, &vec!["red".to_string()]);
        } else {
            panic!("Expected SS");
        }
    }

    /// Different clause keywords can be combined in one expression.
    #[test]
    fn test_combined_set_remove() {
        let expr = parse("SET label = :name REMOVE old_attr").unwrap();
        assert_eq!(expr.set_actions.len(), 1);
        assert_eq!(expr.remove_actions.len(), 1);
    }

    /// Repeating a clause keyword is rejected.
    #[test]
    fn test_duplicate_clause_error() {
        let result = parse("SET a = :v SET b = :w");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only be used once"));
    }
}